From 676288ba24146f62e5247e75870d6751738e4ccd Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sun, 9 Apr 2023 21:31:27 +0200 Subject: [PATCH] fix: fix loading of different 'record' types in the same query Different records in the same query might have different types. Use a separate transformer for each sequence of types. --- docs/news.rst | 2 + psycopg/psycopg/types/composite.py | 79 +++++++++++++++--------------- tests/types/test_composite.py | 18 +++++++ 3 files changed, 59 insertions(+), 40 deletions(-) diff --git a/docs/news.rst b/docs/news.rst index deff90a0f..2b80e05d4 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -17,6 +17,8 @@ Psycopg 3.1.9 (unreleased) (:ticket:`#503`). - Fix canceling running queries on process interruption in async connections (:ticket:`#543`). +- Fix loading ROW values with different types in the same query using the + binary protocol (:ticket:`#545`). Current release diff --git a/psycopg/psycopg/types/composite.py b/psycopg/psycopg/types/composite.py index ef8f6d473..968ee6206 100644 --- a/psycopg/psycopg/types/composite.py +++ b/psycopg/psycopg/types/composite.py @@ -7,12 +7,12 @@ Support for composite types adaptation. import re import struct from collections import namedtuple -from typing import Any, Callable, cast, Iterator, List, Optional +from typing import Any, Callable, cast, Dict, Iterator, List, Optional from typing import Sequence, Tuple, Type from .. import pq +from .. import abc from .. import postgres -from ..abc import AdaptContext, Buffer from ..adapt import Transformer, PyFormat, RecursiveDumper, Loader from .._struct import pack_len, unpack_len from ..postgres import TEXT_OID @@ -22,7 +22,7 @@ from .._encodings import _as_python_identifier _struct_oidlen = struct.Struct("!Ii") _pack_oidlen = cast(Callable[[int, int], bytes], _struct_oidlen.pack) _unpack_oidlen = cast( - Callable[[Buffer, int], Tuple[int, int]], _struct_oidlen.unpack_from + Callable[[abc.Buffer, int], Tuple[int, int]], _struct_oidlen.unpack_from ) @@ -33,7 +33,7 @@ class SequenceDumper(RecursiveDumper): if not obj: return start + end - parts: List[Buffer] = [start] + parts: List[abc.Buffer] = [start] for item in obj: if item is None: @@ -72,7 +72,7 @@ class TupleBinaryDumper(RecursiveDumper): # Subclasses must set an info info: CompositeInfo - def __init__(self, cls: type, context: Optional[AdaptContext] = None): + def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None): super().__init__(cls, context) nfields = len(self.info.field_types) self._tx.set_dumper_types(self.info.field_types, self.format) @@ -94,11 +94,11 @@ class TupleBinaryDumper(RecursiveDumper): class BaseCompositeLoader(Loader): - def __init__(self, oid: int, context: Optional[AdaptContext] = None): + def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None): super().__init__(oid, context) self._tx = Transformer(context) - def _parse_record(self, data: Buffer) -> Iterator[Optional[bytes]]: + def _parse_record(self, data: abc.Buffer) -> Iterator[Optional[bytes]]: """ Split a non-empty representation of a composite type into components. @@ -130,7 +130,7 @@ class BaseCompositeLoader(Loader): class RecordLoader(BaseCompositeLoader): - def load(self, data: Buffer) -> Tuple[Any, ...]: + def load(self, data: abc.Buffer) -> Tuple[Any, ...]: if data == b"()": return () @@ -143,38 +143,37 @@ class RecordLoader(BaseCompositeLoader): class RecordBinaryLoader(Loader): format = pq.Format.BINARY - _types_set = False - def __init__(self, oid: int, context: Optional[AdaptContext] = None): + def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None): super().__init__(oid, context) - self._tx = Transformer(context) - - def load(self, data: Buffer) -> Tuple[Any, ...]: - if not self._types_set: - self._config_types(data) - self._types_set = True - - return self._tx.load_sequence( - tuple( - data[offset : offset + length] if length != -1 else None - for _, offset, length in self._walk_record(data) - ) - ) - - def _walk_record(self, data: Buffer) -> Iterator[Tuple[int, int, int]]: - """ - Yield a sequence of (oid, offset, length) for the content of the record - """ + self._ctx = context + # Cache a transformer for each sequence of oid found. + # Usually there will be only one, but if there is more than one + # row in the same query (in different columns, or even in different + # records), oids might differ and we'd need separate transformers. + self._txs: Dict[Tuple[int, ...], abc.Transformer] = {} + + def load(self, data: abc.Buffer) -> Tuple[Any, ...]: nfields = unpack_len(data, 0)[0] - i = 4 + offset = 4 + oids = [] + record = [] for _ in range(nfields): - oid, length = _unpack_oidlen(data, i) - yield oid, i + 8, length - i += (8 + length) if length > 0 else 8 + oid, length = _unpack_oidlen(data, offset) + offset += 8 + record.append(data[offset : offset + length] if length != -1 else None) + oids.append(oid) + if length >= 0: + offset += length + + key = tuple(oids) + try: + tx = self._txs[key] + except KeyError: + tx = self._txs[key] = Transformer(self._ctx) + tx.set_loader_types(oids, self.format) - def _config_types(self, data: Buffer) -> None: - oids = [r[0] for r in self._walk_record(data)] - self._tx.set_loader_types(oids, self.format) + return tx.load_sequence(tuple(record)) class CompositeLoader(RecordLoader): @@ -182,7 +181,7 @@ class CompositeLoader(RecordLoader): fields_types: List[int] _types_set = False - def load(self, data: Buffer) -> Any: + def load(self, data: abc.Buffer) -> Any: if not self._types_set: self._config_types(data) self._types_set = True @@ -194,7 +193,7 @@ class CompositeLoader(RecordLoader): *self._tx.load_sequence(tuple(self._parse_record(data[1:-1]))) ) - def _config_types(self, data: Buffer) -> None: + def _config_types(self, data: abc.Buffer) -> None: self._tx.set_loader_types(self.fields_types, self.format) @@ -202,14 +201,14 @@ class CompositeBinaryLoader(RecordBinaryLoader): format = pq.Format.BINARY factory: Callable[..., Any] - def load(self, data: Buffer) -> Any: + def load(self, data: abc.Buffer) -> Any: r = super().load(data) return type(self).factory(*r) def register_composite( info: CompositeInfo, - context: Optional[AdaptContext] = None, + context: Optional[abc.AdaptContext] = None, factory: Optional[Callable[..., Any]] = None, ) -> None: """Register the adapters to load and dump a composite type. @@ -279,7 +278,7 @@ def register_composite( info.python_type = factory -def register_default_adapters(context: AdaptContext) -> None: +def register_default_adapters(context: abc.AdaptContext) -> None: adapters = context.adapters adapters.register_dumper(tuple, TupleDumper) adapters.register_loader("record", RecordLoader) diff --git a/tests/types/test_composite.py b/tests/types/test_composite.py index 49e734bf4..ad7db6e12 100644 --- a/tests/types/test_composite.py +++ b/tests/types/test_composite.py @@ -34,6 +34,24 @@ def test_load_record(conn, want, rec): assert res == want +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_different_records_cols(conn, fmt_out): + cur = conn.cursor(binary=fmt_out) + res = cur.execute( + "select row('foo'::text), row('bar'::text, 'baz'::text)" + ).fetchone() + assert res == (("foo",), ("bar", "baz")) + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_different_records_rows(conn, fmt_out): + cur = conn.cursor(binary=fmt_out) + res = cur.execute( + "values (row('foo'::text)), (row('bar'::text, 'baz'::text))" + ).fetchall() + assert res == [(("foo",),), (("bar", "baz"),)] + + @pytest.mark.parametrize("rec, obj", tests_str) def test_dump_tuple(conn, rec, obj): cur = conn.cursor() -- 2.47.2