)
adapters.register_loader(info.oid, loader)
+ # If the factory is a type, register a dumper for it
+ if isinstance(factory, type):
+ dumper = type(
+ f"{info.name.title()}Dumper", (TupleDumper,), {"_oid": info.oid}
+ )
+ adapters.register_dumper(factory, dumper)
+ info.python_type = factory
+
def register_default_adapters(context: AdaptContext) -> None:
adapters = context.adapters
from psycopg.sql import Identifier
from psycopg.adapt import PyFormat as Format
from psycopg.postgres import types as builtins
-from psycopg.types.composite import CompositeInfo
+from psycopg.types.composite import CompositeInfo, TupleDumper
tests_str = [
self.foo, self.bar, self.baz = args
info.register(conn, factory=MyThing)
+ assert info.python_type is MyThing
cur = conn.cursor(binary=fmt_out)
res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0]
for oid in (info.oid, info.array_oid):
assert postgres.adapters._loaders[fmt].pop(oid)
+ for fmt in (Format.AUTO, Format.TEXT):
+ assert postgres.adapters._dumpers[fmt].pop(info.python_type)
+
+ assert info.python_type not in postgres.adapters._dumpers[Format.BINARY]
+
cur = conn.cursor()
info.register(cur)
for fmt in (pq.Format.TEXT, pq.Format.BINARY):
for oid in (info.oid, info.array_oid):
assert oid not in postgres.adapters._loaders[fmt]
assert oid in conn.adapters._loaders[fmt]
+
+
+def test_type_dumper_registered(conn, testcomp):
+ info = CompositeInfo.fetch(conn, "testcomp")
+ info.register(conn)
+ assert issubclass(info.python_type, tuple)
+ assert info.python_type.__name__ == "testcomp"
+ d = conn.adapters.get_dumper(info.python_type, "s")
+ assert issubclass(d, TupleDumper)
+ assert d is not TupleDumper
+
+ tc = info.python_type("foo", 42, 3.14)
+ cur = conn.execute("select pg_typeof(%s)", [tc])
+ assert cur.fetchone()[0] == "testcomp"
+
+
+def test_callable_dumper_not_registered(conn, testcomp):
+ info = CompositeInfo.fetch(conn, "testcomp")
+
+ def fac(*args):
+ return args + (args[-1],)
+
+ info.register(conn, factory=fac)
+ assert info.python_type is None
+
+ # but the loader is registered
+ cur = conn.execute("select '(foo,42,3.14)'::testcomp")
+ assert cur.fetchone()[0] == ("foo", 42, 3.14, 3.14)