]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: don't reuse the same Transformer in composite dumper 550/head
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 16 Apr 2023 01:30:37 +0000 (03:30 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 16 Apr 2023 01:47:18 +0000 (03:47 +0200)
We need different dumpers because, in case a composite contains another
composite, we need to call `dump_sequence()` on different sequences, so
we row dumpers must be distinct.

Close #547

docs/news.rst
psycopg/psycopg/types/composite.py
tests/types/test_composite.py

index b76a973059c4322b5dbe0a55db00db9d3ea6d235..2798a39ac28c74a2914a40b0651274807da89333 100644 (file)
@@ -21,6 +21,7 @@ Psycopg 3.1.9 (unreleased)
   (:ticket:`#543`).
 - Fix loading ROW values with different types in the same query using the
   binary protocol (:ticket:`#545`).
+- Fix dumping recursive composite types (:ticket:`#547`).
 
 
 Current release
index 968ee6206d6383bf5a1800d135e424798ecdac08..40a1e176b161ff3b350956f16190d8ac8a97117a 100644 (file)
@@ -13,7 +13,7 @@ from typing import Sequence, Tuple, Type
 from .. import pq
 from .. import abc
 from .. import postgres
-from ..adapt import Transformer, PyFormat, RecursiveDumper, Loader
+from ..adapt import Transformer, PyFormat, RecursiveDumper, Loader, Dumper
 from .._struct import pack_len, unpack_len
 from ..postgres import TEXT_OID
 from .._typeinfo import CompositeInfo as CompositeInfo  # exported here
@@ -66,7 +66,7 @@ class TupleDumper(SequenceDumper):
         return self._dump_sequence(obj, b"(", b")", b",")
 
 
-class TupleBinaryDumper(RecursiveDumper):
+class TupleBinaryDumper(Dumper):
     format = pq.Format.BINARY
 
     # Subclasses must set an info
@@ -74,8 +74,15 @@ class TupleBinaryDumper(RecursiveDumper):
 
     def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
         super().__init__(cls, context)
-        nfields = len(self.info.field_types)
+
+        # Note: this class is not a RecursiveDumper because it would use the
+        # same Transformer of the context, which would confuse dump_sequence()
+        # in case the composite contains another composite. Make sure to use
+        # a separate Transformer instance instead.
+        self._tx = Transformer(context)
         self._tx.set_dumper_types(self.info.field_types, self.format)
+
+        nfields = len(self.info.field_types)
         self._formats = (PyFormat.from_pq(self.format),) * nfields
 
     def dump(self, obj: Tuple[Any, ...]) -> bytearray:
index ad7db6e12fbc7176d8dadedf506165eff276ad1d..2a2a3a87806bd07276451df401d1219a3d14b022 100644 (file)
@@ -141,10 +141,12 @@ def testcomp(svcconn):
         """
         create schema if not exists testschema;
 
+        drop type if exists testcomp2 cascade;
         drop type if exists testcomp cascade;
         drop type if exists testschema.testcomp cascade;
 
         create type testcomp as (foo text, bar int8, baz float8);
+        create type testcomp2 as (qux int8, quux testcomp);
         create type testschema.testcomp as (foo text, bar int8, qux bool);
         """
     )
@@ -238,6 +240,28 @@ def test_dump_composite_null(conn, fmt_in, testcomp):
     assert rec[0] is True, rec[1]
 
 
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_recursive_composite(conn, fmt_in, testcomp):
+    info = CompositeInfo.fetch(conn, "testcomp")
+    info2 = CompositeInfo.fetch(conn, "testcomp2")
+
+    cur = conn.cursor()
+    register_composite(info, cur)
+    register_composite(info2, cur)
+    testcomp = info.python_type
+    testcomp2 = info2.python_type
+
+    obj = testcomp2(42, testcomp("foo", 1, None))
+    rec = cur.execute(
+        f"""
+        select row(42, row('foo', 1, NULL)::testcomp)::testcomp2 = %(obj){fmt_in.value},
+            %(obj){fmt_in.value}::text
+        """,
+        {"obj": obj},
+    ).fetchone()
+    assert rec[0] is True, rec[1]
+
+
 @pytest.mark.parametrize("fmt_out", pq.Format)
 def test_load_composite(conn, testcomp, fmt_out):
     info = CompositeInfo.fetch(conn, "testcomp")
@@ -256,6 +280,31 @@ def test_load_composite(conn, testcomp, fmt_out):
     assert isinstance(res[0].baz, float)
 
 
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_recursive_composite(conn, testcomp, fmt_out):
+    info = CompositeInfo.fetch(conn, "testcomp")
+    info2 = CompositeInfo.fetch(conn, "testcomp2")
+
+    register_composite(info, conn)
+    register_composite(info2, conn)
+
+    cur = conn.cursor(binary=fmt_out)
+    cur.execute("select row(42, row('hello', 10, 20)::testcomp)::testcomp2")
+    res = cur.fetchone()[0]
+    assert res.qux == 42
+    assert res.quux.foo == "hello"
+    assert res.quux.bar == 10
+    assert res.quux.baz == 20.0
+    assert isinstance(res.quux.baz, float)
+
+    cur.execute("select array[row(42, row('hello', 10, 30)::testcomp)::testcomp2]")
+    res = cur.fetchone()[0]
+    assert len(res) == 1
+    assert res[0].qux == 42
+    assert res[0].quux.baz == 30.0
+    assert isinstance(res[0].quux.baz, float)
+
+
 @pytest.mark.parametrize("fmt_out", pq.Format)
 def test_load_composite_factory(conn, testcomp, fmt_out):
     info = CompositeInfo.fetch(conn, "testcomp")