git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: fix loading of different 'record' types in the same query
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 9 Apr 2023 19:31:27 +0000 (21:31 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 10 Apr 2023 08:44:35 +0000 (10:44 +0200)
Different records in the same query might have different types.
Use a separate transformer for each sequence of types.

docs/news.rst
psycopg/psycopg/types/composite.py
tests/types/test_composite.py

index 36166609f071b9a1e456dd423380020aa4577c45..5255c149f0550e81a6526b83bafefcf8b983769c 100644 (file)
@@ -17,6 +17,8 @@ Psycopg 3.1.9 (unreleased)
   (:ticket:`#503`).
 - Fix canceling running queries on process interruption in async connections
   (:ticket:`#543`).
+- Fix loading ROW values with different types in the same query using the
+  binary protocol (:ticket:`#545`).
 
 
 Current release
index e2a5d0d94bc335ce9ad09b845df38e542d8176b7..b643d9bec6b9c0c832704db073d9a4d0ae45217d 100644 (file)
@@ -7,13 +7,13 @@ Support for composite types adaptation.
 import re
 import struct
 from collections import namedtuple
-from typing import Any, Callable, cast, Iterator, List, Optional
+from typing import Any, Callable, cast, Dict, Iterator, List, Optional
 from typing import Sequence, Tuple, Type, TYPE_CHECKING
 
 from .. import pq
+from .. import abc
 from .. import sql
 from .. import postgres
-from ..abc import AdaptContext, Buffer, Query
 from ..adapt import Transformer, PyFormat, RecursiveDumper, Loader
 from .._oids import TEXT_OID
 from .._struct import pack_len, unpack_len
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
 _struct_oidlen = struct.Struct("!Ii")
 _pack_oidlen = cast(Callable[[int, int], bytes], _struct_oidlen.pack)
 _unpack_oidlen = cast(
-    Callable[[Buffer, int], Tuple[int, int]], _struct_oidlen.unpack_from
+    Callable[[abc.Buffer, int], Tuple[int, int]], _struct_oidlen.unpack_from
 )
 
 
@@ -50,7 +50,7 @@ class CompositeInfo(TypeInfo):
         self.python_type: Optional[type] = None
 
     @classmethod
-    def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
+    def _get_info_query(cls, conn: "BaseConnection[Any]") -> abc.Query:
         return sql.SQL(
             """\
 SELECT
@@ -87,7 +87,7 @@ class SequenceDumper(RecursiveDumper):
         if not obj:
             return start + end
 
-        parts: List[Buffer] = [start]
+        parts: List[abc.Buffer] = [start]
 
         for item in obj:
             if item is None:
@@ -126,7 +126,7 @@ class TupleBinaryDumper(RecursiveDumper):
     # Subclasses must set an info
     info: CompositeInfo
 
-    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+    def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
         super().__init__(cls, context)
         nfields = len(self.info.field_types)
         self._tx.set_dumper_types(self.info.field_types, self.format)
@@ -148,11 +148,11 @@ class TupleBinaryDumper(RecursiveDumper):
 
 
 class BaseCompositeLoader(Loader):
-    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+    def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
         super().__init__(oid, context)
         self._tx = Transformer(context)
 
-    def _parse_record(self, data: Buffer) -> Iterator[Optional[bytes]]:
+    def _parse_record(self, data: abc.Buffer) -> Iterator[Optional[bytes]]:
         """
         Split a non-empty representation of a composite type into components.
 
@@ -184,7 +184,7 @@ class BaseCompositeLoader(Loader):
 
 
 class RecordLoader(BaseCompositeLoader):
-    def load(self, data: Buffer) -> Tuple[Any, ...]:
+    def load(self, data: abc.Buffer) -> Tuple[Any, ...]:
         if data == b"()":
             return ()
 
@@ -197,38 +197,37 @@ class RecordLoader(BaseCompositeLoader):
 
 class RecordBinaryLoader(Loader):
     format = pq.Format.BINARY
-    _types_set = False
 
-    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+    def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
         super().__init__(oid, context)
-        self._tx = Transformer(context)
-
-    def load(self, data: Buffer) -> Tuple[Any, ...]:
-        if not self._types_set:
-            self._config_types(data)
-            self._types_set = True
-
-        return self._tx.load_sequence(
-            tuple(
-                data[offset : offset + length] if length != -1 else None
-                for _, offset, length in self._walk_record(data)
-            )
-        )
-
-    def _walk_record(self, data: Buffer) -> Iterator[Tuple[int, int, int]]:
-        """
-        Yield a sequence of (oid, offset, length) for the content of the record
-        """
+        self._ctx = context
+        # Cache a transformer for each sequence of oids found.
+        # Usually there will be only one, but if there is more than one
+        # row in the same query (in different columns, or even in different
+        # records), oids might differ and we'd need separate transformers.
+        self._txs: Dict[Tuple[int, ...], abc.Transformer] = {}
+
+    def load(self, data: abc.Buffer) -> Tuple[Any, ...]:
         nfields = unpack_len(data, 0)[0]
-        i = 4
+        offset = 4
+        oids = []
+        record = []
         for _ in range(nfields):
-            oid, length = _unpack_oidlen(data, i)
-            yield oid, i + 8, length
-            i += (8 + length) if length > 0 else 8
+            oid, length = _unpack_oidlen(data, offset)
+            offset += 8
+            record.append(data[offset : offset + length] if length != -1 else None)
+            oids.append(oid)
+            if length >= 0:
+                offset += length
+
+        key = tuple(oids)
+        try:
+            tx = self._txs[key]
+        except KeyError:
+            tx = self._txs[key] = Transformer(self._ctx)
+            tx.set_loader_types(oids, self.format)
 
-    def _config_types(self, data: Buffer) -> None:
-        oids = [r[0] for r in self._walk_record(data)]
-        self._tx.set_loader_types(oids, self.format)
+        return tx.load_sequence(tuple(record))
 
 
 class CompositeLoader(RecordLoader):
@@ -236,7 +235,7 @@ class CompositeLoader(RecordLoader):
     fields_types: List[int]
     _types_set = False
 
-    def load(self, data: Buffer) -> Any:
+    def load(self, data: abc.Buffer) -> Any:
         if not self._types_set:
             self._config_types(data)
             self._types_set = True
@@ -248,7 +247,7 @@ class CompositeLoader(RecordLoader):
             *self._tx.load_sequence(tuple(self._parse_record(data[1:-1])))
         )
 
-    def _config_types(self, data: Buffer) -> None:
+    def _config_types(self, data: abc.Buffer) -> None:
         self._tx.set_loader_types(self.fields_types, self.format)
 
 
@@ -256,14 +255,14 @@ class CompositeBinaryLoader(RecordBinaryLoader):
     format = pq.Format.BINARY
     factory: Callable[..., Any]
 
-    def load(self, data: Buffer) -> Any:
+    def load(self, data: abc.Buffer) -> Any:
         r = super().load(data)
         return type(self).factory(*r)
 
 
 def register_composite(
     info: CompositeInfo,
-    context: Optional[AdaptContext] = None,
+    context: Optional[abc.AdaptContext] = None,
     factory: Optional[Callable[..., Any]] = None,
 ) -> None:
     """Register the adapters to load and dump a composite type.
@@ -333,7 +332,7 @@ def register_composite(
         info.python_type = factory
 
 
-def register_default_adapters(context: AdaptContext) -> None:
+def register_default_adapters(context: abc.AdaptContext) -> None:
     adapters = context.adapters
     adapters.register_dumper(tuple, TupleDumper)
     adapters.register_loader("record", RecordLoader)
index 49e734bf44bef432cbeb91cffe54953e01ff7a19..ad7db6e12fbc7176d8dadedf506165eff276ad1d 100644 (file)
@@ -34,6 +34,24 @@ def test_load_record(conn, want, rec):
     assert res == want
 
 
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_different_records_cols(conn, fmt_out):
+    cur = conn.cursor(binary=fmt_out)
+    res = cur.execute(
+        "select row('foo'::text), row('bar'::text, 'baz'::text)"
+    ).fetchone()
+    assert res == (("foo",), ("bar", "baz"))
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_different_records_rows(conn, fmt_out):
+    cur = conn.cursor(binary=fmt_out)
+    res = cur.execute(
+        "values (row('foo'::text)), (row('bar'::text, 'baz'::text))"
+    ).fetchall()
+    assert res == [(("foo",),), (("bar", "baz"),)]
+
+
 @pytest.mark.parametrize("rec, obj", tests_str)
 def test_dump_tuple(conn, rec, obj):
     cur = conn.cursor()