from .. import pq
from .. import abc
from .. import postgres
-from ..adapt import Transformer, PyFormat, RecursiveDumper, Loader
+from ..adapt import Transformer, PyFormat, RecursiveDumper, Loader, Dumper
from .._struct import pack_len, unpack_len
from ..postgres import TEXT_OID
from .._typeinfo import CompositeInfo as CompositeInfo # exported here
return self._dump_sequence(obj, b"(", b")", b",")
-class TupleBinaryDumper(RecursiveDumper):
+class TupleBinaryDumper(Dumper):
format = pq.Format.BINARY
# Subclasses must set an info
def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
super().__init__(cls, context)
- nfields = len(self.info.field_types)
+
+ # Note: this class is not a RecursiveDumper because it would use the
+ # same Transformer of the context, which would confuse dump_sequence()
+ # in case the composite contains another composite. Make sure to use
+ # a separate Transformer instance instead.
+ self._tx = Transformer(context)
self._tx.set_dumper_types(self.info.field_types, self.format)
+
+ nfields = len(self.info.field_types)
self._formats = (PyFormat.from_pq(self.format),) * nfields
def dump(self, obj: Tuple[Any, ...]) -> bytearray:
"""
create schema if not exists testschema;
+ drop type if exists testcomp2 cascade;
drop type if exists testcomp cascade;
drop type if exists testschema.testcomp cascade;
create type testcomp as (foo text, bar int8, baz float8);
+ create type testcomp2 as (qux int8, quux testcomp);
create type testschema.testcomp as (foo text, bar int8, qux bool);
"""
)
assert rec[0] is True, rec[1]
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_recursive_composite(conn, fmt_in, testcomp):
+ info = CompositeInfo.fetch(conn, "testcomp")
+ info2 = CompositeInfo.fetch(conn, "testcomp2")
+
+ cur = conn.cursor()
+ register_composite(info, cur)
+ register_composite(info2, cur)
+ testcomp = info.python_type
+ testcomp2 = info2.python_type
+
+ obj = testcomp2(42, testcomp("foo", 1, None))
+ rec = cur.execute(
+ f"""
+ select row(42, row('foo', 1, NULL)::testcomp)::testcomp2 = %(obj){fmt_in.value},
+ %(obj){fmt_in.value}::text
+ """,
+ {"obj": obj},
+ ).fetchone()
+ assert rec[0] is True, rec[1]
+
+
@pytest.mark.parametrize("fmt_out", pq.Format)
def test_load_composite(conn, testcomp, fmt_out):
info = CompositeInfo.fetch(conn, "testcomp")
assert isinstance(res[0].baz, float)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_recursive_composite(conn, testcomp, fmt_out):
+ info = CompositeInfo.fetch(conn, "testcomp")
+ info2 = CompositeInfo.fetch(conn, "testcomp2")
+
+ register_composite(info, conn)
+ register_composite(info2, conn)
+
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute("select row(42, row('hello', 10, 20)::testcomp)::testcomp2")
+ res = cur.fetchone()[0]
+ assert res.qux == 42
+ assert res.quux.foo == "hello"
+ assert res.quux.bar == 10
+ assert res.quux.baz == 20.0
+ assert isinstance(res.quux.baz, float)
+
+ cur.execute("select array[row(42, row('hello', 10, 30)::testcomp)::testcomp2]")
+ res = cur.fetchone()[0]
+ assert len(res) == 1
+ assert res[0].qux == 42
+ assert res[0].quux.baz == 30.0
+ assert isinstance(res[0].quux.baz, float)
+
+
@pytest.mark.parametrize("fmt_out", pq.Format)
def test_load_composite_factory(conn, testcomp, fmt_out):
info = CompositeInfo.fetch(conn, "testcomp")