fix: fix loading of different 'record' types in the same query
author    Daniele Varrazzo <daniele.varrazzo@gmail.com>
          Sun, 9 Apr 2023 19:31:27 +0000 (21:31 +0200)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
          Mon, 10 Apr 2023 08:45:44 +0000 (10:45 +0200)
Different records in the same query might have different types.
Use a separate transformer for each sequence of types.
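
A minimal reproduction sketch (illustrative, not part of the commit; the
connection string and column types are placeholder assumptions). Before this
change the binary record loader configured its loader types only once, from
the first record it saw, so a second record with a different field signature
in the same query could fail or be decoded incorrectly:

    import psycopg

    with psycopg.connect("dbname=test") as conn:
        cur = conn.cursor(binary=True)  # force the binary protocol
        cur.execute("select row('foo'::text), row(1::int, 2::int)")
        print(cur.fetchone())  # expected with the fix: (('foo',), (1, 2))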

docs/news.rst
psycopg/psycopg/types/composite.py
tests/types/test_composite.py

diff --git a/docs/news.rst b/docs/news.rst
index deff90a0f7f4c8889b7bc5edae90674485f2f101..2b80e05d410c7d9df71577c12a0213496b3548f6 100644
@@ -17,6 +17,8 @@ Psycopg 3.1.9 (unreleased)
   (:ticket:`#503`).
 - Fix canceling running queries on process interruption in async connections
   (:ticket:`#543`).
+- Fix loading ROW values with different types in the same query using the
+  binary protocol (:ticket:`#545`).
 
 
 Current release
diff --git a/psycopg/psycopg/types/composite.py b/psycopg/psycopg/types/composite.py
index ef8f6d4736d665ae575e0ab93f7e50bba1040679..968ee6206d6383bf5a1800d135e424798ecdac08 100644
@@ -7,12 +7,12 @@ Support for composite types adaptation.
 import re
 import struct
 from collections import namedtuple
-from typing import Any, Callable, cast, Iterator, List, Optional
+from typing import Any, Callable, cast, Dict, Iterator, List, Optional
 from typing import Sequence, Tuple, Type
 
 from .. import pq
+from .. import abc
 from .. import postgres
-from ..abc import AdaptContext, Buffer
 from ..adapt import Transformer, PyFormat, RecursiveDumper, Loader
 from .._struct import pack_len, unpack_len
 from ..postgres import TEXT_OID
@@ -22,7 +22,7 @@ from .._encodings import _as_python_identifier
 _struct_oidlen = struct.Struct("!Ii")
 _pack_oidlen = cast(Callable[[int, int], bytes], _struct_oidlen.pack)
 _unpack_oidlen = cast(
-    Callable[[Buffer, int], Tuple[int, int]], _struct_oidlen.unpack_from
+    Callable[[abc.Buffer, int], Tuple[int, int]], _struct_oidlen.unpack_from
 )
 
 
@@ -33,7 +33,7 @@ class SequenceDumper(RecursiveDumper):
         if not obj:
             return start + end
 
-        parts: List[Buffer] = [start]
+        parts: List[abc.Buffer] = [start]
 
         for item in obj:
             if item is None:
@@ -72,7 +72,7 @@ class TupleBinaryDumper(RecursiveDumper):
     # Subclasses must set an info
     info: CompositeInfo
 
-    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+    def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
         super().__init__(cls, context)
         nfields = len(self.info.field_types)
         self._tx.set_dumper_types(self.info.field_types, self.format)
@@ -94,11 +94,11 @@ class TupleBinaryDumper(RecursiveDumper):
 
 
 class BaseCompositeLoader(Loader):
-    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+    def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
         super().__init__(oid, context)
         self._tx = Transformer(context)
 
-    def _parse_record(self, data: Buffer) -> Iterator[Optional[bytes]]:
+    def _parse_record(self, data: abc.Buffer) -> Iterator[Optional[bytes]]:
         """
         Split a non-empty representation of a composite type into components.
 
@@ -130,7 +130,7 @@ class BaseCompositeLoader(Loader):
 
 
 class RecordLoader(BaseCompositeLoader):
-    def load(self, data: Buffer) -> Tuple[Any, ...]:
+    def load(self, data: abc.Buffer) -> Tuple[Any, ...]:
         if data == b"()":
             return ()
 
@@ -143,38 +143,37 @@ class RecordLoader(BaseCompositeLoader):
 
 class RecordBinaryLoader(Loader):
     format = pq.Format.BINARY
-    _types_set = False
 
-    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+    def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
         super().__init__(oid, context)
-        self._tx = Transformer(context)
-
-    def load(self, data: Buffer) -> Tuple[Any, ...]:
-        if not self._types_set:
-            self._config_types(data)
-            self._types_set = True
-
-        return self._tx.load_sequence(
-            tuple(
-                data[offset : offset + length] if length != -1 else None
-                for _, offset, length in self._walk_record(data)
-            )
-        )
-
-    def _walk_record(self, data: Buffer) -> Iterator[Tuple[int, int, int]]:
-        """
-        Yield a sequence of (oid, offset, length) for the content of the record
-        """
+        self._ctx = context
+        # Cache a transformer for each sequence of oids found.
+        # Usually there will be only one, but if there is more than one
+        # record in the same query (in different columns, or even in different
+        # rows), oids might differ and we'd need separate transformers.
+        self._txs: Dict[Tuple[int, ...], abc.Transformer] = {}
+
+    def load(self, data: abc.Buffer) -> Tuple[Any, ...]:
         nfields = unpack_len(data, 0)[0]
-        i = 4
+        offset = 4
+        oids = []
+        record = []
         for _ in range(nfields):
-            oid, length = _unpack_oidlen(data, i)
-            yield oid, i + 8, length
-            i += (8 + length) if length > 0 else 8
+            oid, length = _unpack_oidlen(data, offset)
+            offset += 8
+            record.append(data[offset : offset + length] if length != -1 else None)
+            oids.append(oid)
+            if length >= 0:
+                offset += length
+
+        key = tuple(oids)
+        try:
+            tx = self._txs[key]
+        except KeyError:
+            tx = self._txs[key] = Transformer(self._ctx)
+            tx.set_loader_types(oids, self.format)
 
-    def _config_types(self, data: Buffer) -> None:
-        oids = [r[0] for r in self._walk_record(data)]
-        self._tx.set_loader_types(oids, self.format)
+        return tx.load_sequence(tuple(record))
 
 
 class CompositeLoader(RecordLoader):
@@ -182,7 +181,7 @@ class CompositeLoader(RecordLoader):
     fields_types: List[int]
     _types_set = False
 
-    def load(self, data: Buffer) -> Any:
+    def load(self, data: abc.Buffer) -> Any:
         if not self._types_set:
             self._config_types(data)
             self._types_set = True
@@ -194,7 +193,7 @@ class CompositeLoader(RecordLoader):
             *self._tx.load_sequence(tuple(self._parse_record(data[1:-1])))
         )
 
-    def _config_types(self, data: Buffer) -> None:
+    def _config_types(self, data: abc.Buffer) -> None:
         self._tx.set_loader_types(self.fields_types, self.format)
 
 
@@ -202,14 +201,14 @@ class CompositeBinaryLoader(RecordBinaryLoader):
     format = pq.Format.BINARY
     factory: Callable[..., Any]
 
-    def load(self, data: Buffer) -> Any:
+    def load(self, data: abc.Buffer) -> Any:
         r = super().load(data)
         return type(self).factory(*r)
 
 
 def register_composite(
     info: CompositeInfo,
-    context: Optional[AdaptContext] = None,
+    context: Optional[abc.AdaptContext] = None,
     factory: Optional[Callable[..., Any]] = None,
 ) -> None:
     """Register the adapters to load and dump a composite type.
@@ -279,7 +278,7 @@ def register_composite(
         info.python_type = factory
 
 
-def register_default_adapters(context: AdaptContext) -> None:
+def register_default_adapters(context: abc.AdaptContext) -> None:
     adapters = context.adapters
     adapters.register_dumper(tuple, TupleDumper)
     adapters.register_loader("record", RecordLoader)
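
As context for the new load() loop above, a standalone sketch of the binary
'record' wire format it walks: a 4-byte field count, then for each field a
4-byte type OID, a 4-byte length (-1 for NULL) and the value bytes. The
payload below is hand-built for illustration (OID 25 is text):

    import struct

    def walk_record(data: bytes):
        nfields, = struct.unpack_from("!i", data, 0)
        offset = 4
        for _ in range(nfields):
            oid, length = struct.unpack_from("!Ii", data, offset)
            offset += 8
            value = data[offset:offset + length] if length != -1 else None
            if length >= 0:
                offset += length
            yield oid, value

    # two text fields (oid 25): 'foo' and 'bar'
    payload = struct.pack("!i", 2)
    for v in (b"foo", b"bar"):
        payload += struct.pack("!Ii", 25, len(v)) + v
    print(list(walk_record(payload)))  # [(25, b'foo'), (25, b'bar')]
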
diff --git a/tests/types/test_composite.py b/tests/types/test_composite.py
index 49e734bf44bef432cbeb91cffe54953e01ff7a19..ad7db6e12fbc7176d8dadedf506165eff276ad1d 100644
@@ -34,6 +34,24 @@ def test_load_record(conn, want, rec):
     assert res == want
 
 
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_different_records_cols(conn, fmt_out):
+    cur = conn.cursor(binary=fmt_out)
+    res = cur.execute(
+        "select row('foo'::text), row('bar'::text, 'baz'::text)"
+    ).fetchone()
+    assert res == (("foo",), ("bar", "baz"))
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_different_records_rows(conn, fmt_out):
+    cur = conn.cursor(binary=fmt_out)
+    res = cur.execute(
+        "values (row('foo'::text)), (row('bar'::text, 'baz'::text))"
+    ).fetchall()
+    assert res == [(("foo",),), (("bar", "baz"),)]
+
+
 @pytest.mark.parametrize("rec, obj", tests_str)
 def test_dump_tuple(conn, rec, obj):
     cur = conn.cursor()