]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Register composite dumper if the factory is a type
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 26 Aug 2021 15:24:26 +0000 (17:24 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 26 Aug 2021 18:46:41 +0000 (20:46 +0200)
Expose the type as info.python_type after registering.

psycopg/psycopg/_typeinfo.py
psycopg/psycopg/types/composite.py
tests/types/test_composite.py

index 17f572fda5e04bffd0f9795ca1006260d15d9c9c..0f7ff2026a5da937ddfb78da2b95b1898e47703d 100644 (file)
@@ -191,6 +191,8 @@ class CompositeInfo(TypeInfo):
         super().__init__(name, oid, array_oid)
         self.field_names = field_names
         self.field_types = field_types
+        # Will be set by register() if the `factory` is a type
+        self.python_type: Optional[type] = None
 
     def register(
         self,
index 7112384007430448e0b84dae8df74bd5120d5129..e1a84394662274bc94b6fd035e676e25034cd93d 100644 (file)
@@ -212,6 +212,14 @@ def register_composite(
     )
     adapters.register_loader(info.oid, loader)
 
+    # If the factory is a type, register a dumper for it
+    if isinstance(factory, type):
+        dumper = type(
+            f"{info.name.title()}Dumper", (TupleDumper,), {"_oid": info.oid}
+        )
+        adapters.register_dumper(factory, dumper)
+        info.python_type = factory
+
 
 def register_default_adapters(context: AdaptContext) -> None:
     adapters = context.adapters
index 749c7376b0d2b80dfb1ec7b2f9f777613f6aa19e..952959d93c9e885f8c9effe02ea190f5a8177901 100644 (file)
@@ -4,7 +4,7 @@ from psycopg import pq, postgres
 from psycopg.sql import Identifier
 from psycopg.adapt import PyFormat as Format
 from psycopg.postgres import types as builtins
-from psycopg.types.composite import CompositeInfo
+from psycopg.types.composite import CompositeInfo, TupleDumper
 
 
 tests_str = [
@@ -198,6 +198,7 @@ def test_load_composite_factory(conn, testcomp, fmt_out):
             self.foo, self.bar, self.baz = args
 
     info.register(conn, factory=MyThing)
+    assert info.python_type is MyThing
 
     cur = conn.cursor(binary=fmt_out)
     res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0]
@@ -220,6 +221,11 @@ def test_register_scope(conn, testcomp):
         for oid in (info.oid, info.array_oid):
             assert postgres.adapters._loaders[fmt].pop(oid)
 
+    for fmt in (Format.AUTO, Format.TEXT):
+        assert postgres.adapters._dumpers[fmt].pop(info.python_type)
+
+    assert info.python_type not in postgres.adapters._dumpers[Format.BINARY]
+
     cur = conn.cursor()
     info.register(cur)
     for fmt in (pq.Format.TEXT, pq.Format.BINARY):
@@ -233,3 +239,31 @@ def test_register_scope(conn, testcomp):
         for oid in (info.oid, info.array_oid):
             assert oid not in postgres.adapters._loaders[fmt]
             assert oid in conn.adapters._loaders[fmt]
+
+
+def test_type_dumper_registered(conn, testcomp):
+    info = CompositeInfo.fetch(conn, "testcomp")
+    info.register(conn)
+    assert issubclass(info.python_type, tuple)
+    assert info.python_type.__name__ == "testcomp"
+    d = conn.adapters.get_dumper(info.python_type, "s")
+    assert issubclass(d, TupleDumper)
+    assert d is not TupleDumper
+
+    tc = info.python_type("foo", 42, 3.14)
+    cur = conn.execute("select pg_typeof(%s)", [tc])
+    assert cur.fetchone()[0] == "testcomp"
+
+
+def test_callable_dumper_not_registered(conn, testcomp):
+    info = CompositeInfo.fetch(conn, "testcomp")
+
+    def fac(*args):
+        return args + (args[-1],)
+
+    info.register(conn, factory=fac)
+    assert info.python_type is None
+
+    # but the loader is registered
+    cur = conn.execute("select '(foo,42,3.14)'::testcomp")
+    assert cur.fetchone()[0] == ("foo", 42, 3.14, 3.14)