From: Daniele Varrazzo Date: Fri, 27 Aug 2021 16:16:33 +0000 (+0200) Subject: Return only the params from Transformer.dump_sequence() X-Git-Tag: 3.0.beta1~29^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7667819b4f87584503c3b77d9ceb8f86888bb8be;p=thirdparty%2Fpsycopg.git Return only the params from Transformer.dump_sequence() The types and formats can be read as object attributes. --- diff --git a/psycopg/psycopg/_queries.py b/psycopg/psycopg/_queries.py index 25091f438..2c1f72dcc 100644 --- a/psycopg/psycopg/_queries.py +++ b/psycopg/psycopg/_queries.py @@ -12,7 +12,7 @@ from functools import lru_cache from . import pq from . import errors as e from .sql import Composable -from .abc import Query, Params +from .abc import Buffer, Query, Params from ._enums import PyFormat if TYPE_CHECKING: @@ -38,7 +38,7 @@ class PostgresQuery: def __init__(self, transformer: "Transformer"): self._tx = transformer - self.params: Optional[List[Optional[bytes]]] = None + self.params: Optional[Sequence[Optional[Buffer]]] = None # these are tuples so they can be used as keys e.g. in prepared stmts self.types: Tuple[int, ...] = () @@ -91,9 +91,9 @@ class PostgresQuery: self._parts, vars, self._order ) assert self._want_formats is not None - self.params, self.types, self.formats = self._tx.dump_sequence( - params, self._want_formats - ) + self.params = self._tx.dump_sequence(params, self._want_formats) + self.types = self._tx.types or () + self.formats = self._tx.formats else: self.params = None self.types = () diff --git a/psycopg/psycopg/_transform.py b/psycopg/psycopg/_transform.py index 97c74d6e4..b15511e4b 100644 --- a/psycopg/psycopg/_transform.py +++ b/psycopg/psycopg/_transform.py @@ -11,7 +11,7 @@ from collections import defaultdict from . import pq from . import postgres from . import errors as e -from .abc import LoadFunc, AdaptContext, PyFormat, DumperKey +from .abc import Buffer, LoadFunc, AdaptContext, PyFormat, DumperKey from .rows import Row, RowMaker from .postgres import INVALID_OID @@ -42,7 +42,12 @@ class Transformer(AdaptContext): _adapters: "AdaptersMap" _pgresult: Optional["PGresult"] = None + types: Tuple[int, ...] + formats: List[pq.Format] + def __init__(self, context: Optional[AdaptContext] = None): + self.types = () + self.formats = [] # WARNING: don't store context, or you'll create a loop with the Cursor if context: @@ -91,52 +96,65 @@ class Transformer(AdaptContext): ) -> None: self._pgresult = result - self._ntuples: int - self._nfields: int if not result: self._nfields = self._ntuples = 0 if set_loaders: self._row_loaders = [] return - nf = self._nfields = result.nfields self._ntuples = result.ntuples + nf = self._nfields = result.nfields + + if not set_loaders: + return + + if not nf: + self._row_loaders = [] + return - if set_loaders: - rc = self._row_loaders = [] - for i in range(nf): - oid = result.ftype(i) - fmt = result.fformat(i) if format is None else format - rc.append(self.get_loader(oid, fmt).load) # type: ignore + fmt: pq.Format + fmt = result.fformat(0) if format is None else format # type: ignore + self._row_loaders = [ + self.get_loader(result.ftype(i), fmt).load for i in range(nf) + ] def set_dumper_types( self, types: Sequence[int], format: pq.Format ) -> None: - dumpers: List[Optional["Dumper"]] = [] - for i in range(len(types)): - dumpers.append(self.get_dumper_by_oid(types[i], format)) - - self._row_dumpers = dumpers + self._row_dumpers = [ + self.get_dumper_by_oid(oid, format) for oid in types + ] + self.types = tuple(types) + self.formats = [format] * len(types) def set_loader_types( self, types: Sequence[int], format: pq.Format ) -> None: - loaders: List[LoadFunc] = [] - for i in range(len(types)): - loaders.append(self.get_loader(types[i], format).load) - - self._row_loaders = loaders + self._row_loaders = [ + self.get_loader(oid, format).load for oid in types + ] def dump_sequence( self, params: Sequence[Any], formats: Sequence[PyFormat] - ) -> Tuple[List[Any], Tuple[int, ...], Sequence[pq.Format]]: - ps: List[Optional[bytes]] = [None] * len(params) - ts = [INVALID_OID] * len(params) - fs: List[pq.Format] = [pq.Format.TEXT] * len(params) + ) -> Sequence[Optional[Buffer]]: + out: List[Optional[Buffer]] = [None] * len(params) - dumpers = self._row_dumpers + change_state = False + + dumpers: List[Optional[Dumper]] = self._row_dumpers + types: Optional[List[int]] = None + pqformats: Optional[List[pq.Format]] = None + + # If we have dumpers, it means dump_sequnece or set_dumper_types were + # called already, in which case self.types and self.formats are set to + # sequences of the right size. We may change their contents if + # now we find a dumper we didn't have before, for instance because in + # an executemany the first records has a null, the second has a value. if not dumpers: - dumpers = self._row_dumpers = [None] * len(params) + change_state = True + dumpers = [None] * len(params) + types = [INVALID_OID] * len(params) + pqformats = [pq.Format.TEXT] * len(params) for i in range(len(params)): param = params[i] @@ -144,11 +162,24 @@ class Transformer(AdaptContext): dumper = dumpers[i] if not dumper: dumper = dumpers[i] = self.get_dumper(param, formats[i]) - ps[i] = dumper.dump(param) - ts[i] = dumper.oid - fs[i] = dumper.format - - return ps, tuple(ts), fs + change_state = True + if not types: + types = list(self.types) + types[i] = dumper.oid + if not pqformats: + pqformats = list(self.formats) + pqformats[i] = dumper.format + + out[i] = dumper.dump(param) + + if change_state: + self._row_dumpers = dumpers + assert types is not None + self.types = tuple(types) + assert pqformats is not None + self.formats = pqformats + + return out def get_dumper(self, obj: Any, format: PyFormat) -> "Dumper": """ diff --git a/psycopg/psycopg/abc.py b/psycopg/psycopg/abc.py index fc86dc441..d6c164998 100644 --- a/psycopg/psycopg/abc.py +++ b/psycopg/psycopg/abc.py @@ -169,6 +169,10 @@ class Loader(Protocol): class Transformer(Protocol): + + types: Tuple[int, ...] + formats: Sequence[pq.Format] + def __init__(self, context: Optional[AdaptContext] = None): ... @@ -205,7 +209,7 @@ class Transformer(Protocol): def dump_sequence( self, params: Sequence[Any], formats: Sequence[PyFormat] - ) -> Tuple[List[Any], Tuple[int, ...], Sequence[pq.Format]]: + ) -> Sequence[Optional[Buffer]]: ... def get_dumper(self, obj: Any, format: PyFormat) -> Dumper: diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py index 61338f6ce..3e15e1547 100644 --- a/psycopg/psycopg/copy.py +++ b/psycopg/psycopg/copy.py @@ -562,7 +562,7 @@ def _format_row_binary( out = bytearray() out += _pack_int2(len(row)) - adapted, _, _ = tx.dump_sequence(row, [PyFormat.BINARY] * len(row)) + adapted = tx.dump_sequence(row, [PyFormat.BINARY] * len(row)) for b in adapted: if b is not None: out += _pack_int4(len(b)) diff --git a/psycopg/psycopg/types/composite.py b/psycopg/psycopg/types/composite.py index 89d793c91..50eea4d8a 100644 --- a/psycopg/psycopg/types/composite.py +++ b/psycopg/psycopg/types/composite.py @@ -80,11 +80,11 @@ class TupleBinaryDumper(RecursiveDumper): super().__init__(cls, context) nfields = len(self.info.field_types) self._tx.set_dumper_types(self.info.field_types, self.format) - self._formats = [PyFormat.from_pq(self.format)] * nfields + self._formats = (PyFormat.from_pq(self.format),) * nfields def dump(self, obj: Tuple[Any, ...]) -> bytearray: out = bytearray(pack_len(len(obj))) - adapted, _, _ = self._tx.dump_sequence(obj, self._formats) + adapted = self._tx.dump_sequence(obj, self._formats) for i in range(len(obj)): b = adapted[i] oid = self.info.field_types[i] diff --git a/psycopg_c/psycopg_c/_psycopg.pyi b/psycopg_c/psycopg_c/_psycopg.pyi index a6e07e88d..ebcb6adde 100644 --- a/psycopg_c/psycopg_c/_psycopg.pyi +++ b/psycopg_c/psycopg_c/_psycopg.pyi @@ -17,6 +17,8 @@ from psycopg.pq.abc import PGconn, PGresult from psycopg.connection import BaseConnection class Transformer(abc.AdaptContext): + types: Tuple[int, ...] + formats: Sequence[pq.Format] def __init__(self, context: Optional[abc.AdaptContext] = None): ... @property def connection(self) -> Optional[BaseConnection[Any]]: ... @@ -39,7 +41,7 @@ class Transformer(abc.AdaptContext): ) -> None: ... def dump_sequence( self, params: Sequence[Any], formats: Sequence[PyFormat] - ) -> Tuple[List[Any], Tuple[int, ...], Sequence[pq.Format]]: ... + ) -> Sequence[Optional[abc.Buffer]]: ... def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper: ... def get_dumper_by_oid(self, oid: int, format: pq.Format) -> abc.Dumper: ... def load_rows( diff --git a/tests/test_copy.py b/tests/test_copy.py index 91b558756..500251d5e 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -488,7 +488,7 @@ def test_copy_query(conn): cur = conn.cursor() with cur.copy("copy (select 1) to stdout") as copy: assert cur._query.query == b"copy (select 1) to stdout" - assert cur._query.params is None + assert not cur._query.params list(copy) diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index 3fbe7941d..40ac2cc04 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -478,7 +478,7 @@ async def test_copy_query(aconn): cur = aconn.cursor() async with cur.copy("copy (select 1) to stdout") as copy: assert cur._query.query == b"copy (select 1) to stdout" - assert cur._query.params is None + assert not cur._query.params async for record in copy: pass diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 686b20902..4445d1898 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -403,7 +403,7 @@ def test_query_params_execute(conn): cur.execute("select 1") assert cur._query.query == b"select 1" - assert cur._query.params is None + assert not cur._query.params with pytest.raises(psycopg.DataError): cur.execute("select %t::int", ["wat"]) diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py index 4f5133456..c561e9fab 100644 --- a/tests/test_cursor_async.py +++ b/tests/test_cursor_async.py @@ -405,7 +405,7 @@ async def test_query_params_execute(aconn): await cur.execute("select 1") assert cur._query.query == b"select 1" - assert cur._query.params is None + assert not cur._query.params with pytest.raises(psycopg.DataError): await cur.execute("select %t::int", ["wat"]) diff --git a/tests/test_query.py b/tests/test_query.py index 0314dd8cf..9765ea9d3 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -68,7 +68,7 @@ def test_split_query_bad(input): @pytest.mark.parametrize( "query, params, want, wformats, wparams", [ - (b"", None, b"", None, None), + (b"", None, b"", (), ()), (b"", [], b"", [], []), (b"%%", [], b"%", [], []), (b"select %t", (1,), b"select $1", [pq.Format.TEXT], [b"1"]),