git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Return only the params from Transformer.dump_sequence()
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 27 Aug 2021 16:16:33 +0000 (18:16 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 27 Aug 2021 22:23:47 +0000 (00:23 +0200)
The types and formats can be read as object attributes.

psycopg/psycopg/_queries.py
psycopg/psycopg/_transform.py
psycopg/psycopg/abc.py
psycopg/psycopg/copy.py
psycopg/psycopg/types/composite.py
psycopg_c/psycopg_c/_psycopg.pyi
tests/test_copy.py
tests/test_copy_async.py
tests/test_cursor.py
tests/test_cursor_async.py
tests/test_query.py

index 25091f43821d041c9cfcddeef287be11bd964e33..2c1f72dcc043b23404c7d890091704a2aa0fcb8a 100644 (file)
@@ -12,7 +12,7 @@ from functools import lru_cache
 from . import pq
 from . import errors as e
 from .sql import Composable
-from .abc import Query, Params
+from .abc import Buffer, Query, Params
 from ._enums import PyFormat
 
 if TYPE_CHECKING:
@@ -38,7 +38,7 @@ class PostgresQuery:
     def __init__(self, transformer: "Transformer"):
         self._tx = transformer
 
-        self.params: Optional[List[Optional[bytes]]] = None
+        self.params: Optional[Sequence[Optional[Buffer]]] = None
         # these are tuples so they can be used as keys e.g. in prepared stmts
         self.types: Tuple[int, ...] = ()
 
@@ -91,9 +91,9 @@ class PostgresQuery:
                 self._parts, vars, self._order
             )
             assert self._want_formats is not None
-            self.params, self.types, self.formats = self._tx.dump_sequence(
-                params, self._want_formats
-            )
+            self.params = self._tx.dump_sequence(params, self._want_formats)
+            self.types = self._tx.types or ()
+            self.formats = self._tx.formats
         else:
             self.params = None
             self.types = ()
index 97c74d6e41f89f054d883e5e4777e62b1aaf4b3d..b15511e4b5a60e321262b3da2a1aa6feaa2ad2a7 100644 (file)
@@ -11,7 +11,7 @@ from collections import defaultdict
 from . import pq
 from . import postgres
 from . import errors as e
-from .abc import LoadFunc, AdaptContext, PyFormat, DumperKey
+from .abc import Buffer, LoadFunc, AdaptContext, PyFormat, DumperKey
 from .rows import Row, RowMaker
 from .postgres import INVALID_OID
 
@@ -42,7 +42,12 @@ class Transformer(AdaptContext):
     _adapters: "AdaptersMap"
     _pgresult: Optional["PGresult"] = None
 
+    types: Tuple[int, ...]
+    formats: List[pq.Format]
+
     def __init__(self, context: Optional[AdaptContext] = None):
+        self.types = ()
+        self.formats = []
 
         # WARNING: don't store context, or you'll create a loop with the Cursor
         if context:
@@ -91,52 +96,65 @@ class Transformer(AdaptContext):
     ) -> None:
         self._pgresult = result
 
-        self._ntuples: int
-        self._nfields: int
         if not result:
             self._nfields = self._ntuples = 0
             if set_loaders:
                 self._row_loaders = []
             return
 
-        nf = self._nfields = result.nfields
         self._ntuples = result.ntuples
+        nf = self._nfields = result.nfields
+
+        if not set_loaders:
+            return
+
+        if not nf:
+            self._row_loaders = []
+            return
 
-        if set_loaders:
-            rc = self._row_loaders = []
-            for i in range(nf):
-                oid = result.ftype(i)
-                fmt = result.fformat(i) if format is None else format
-                rc.append(self.get_loader(oid, fmt).load)  # type: ignore
+        fmt: pq.Format
+        fmt = result.fformat(0) if format is None else format  # type: ignore
+        self._row_loaders = [
+            self.get_loader(result.ftype(i), fmt).load for i in range(nf)
+        ]
 
     def set_dumper_types(
         self, types: Sequence[int], format: pq.Format
     ) -> None:
-        dumpers: List[Optional["Dumper"]] = []
-        for i in range(len(types)):
-            dumpers.append(self.get_dumper_by_oid(types[i], format))
-
-        self._row_dumpers = dumpers
+        self._row_dumpers = [
+            self.get_dumper_by_oid(oid, format) for oid in types
+        ]
+        self.types = tuple(types)
+        self.formats = [format] * len(types)
 
     def set_loader_types(
         self, types: Sequence[int], format: pq.Format
     ) -> None:
-        loaders: List[LoadFunc] = []
-        for i in range(len(types)):
-            loaders.append(self.get_loader(types[i], format).load)
-
-        self._row_loaders = loaders
+        self._row_loaders = [
+            self.get_loader(oid, format).load for oid in types
+        ]
 
     def dump_sequence(
         self, params: Sequence[Any], formats: Sequence[PyFormat]
-    ) -> Tuple[List[Any], Tuple[int, ...], Sequence[pq.Format]]:
-        ps: List[Optional[bytes]] = [None] * len(params)
-        ts = [INVALID_OID] * len(params)
-        fs: List[pq.Format] = [pq.Format.TEXT] * len(params)
+    ) -> Sequence[Optional[Buffer]]:
+        out: List[Optional[Buffer]] = [None] * len(params)
 
-        dumpers = self._row_dumpers
+        change_state = False
+
+        dumpers: List[Optional[Dumper]] = self._row_dumpers
+        types: Optional[List[int]] = None
+        pqformats: Optional[List[pq.Format]] = None
+
+    # If we have dumpers, it means dump_sequence or set_dumper_types were
+    # called already, in which case self.types and self.formats are set to
+    # sequences of the right size. We may change their contents if
+    # now we find a dumper we didn't have before, for instance because in
+    # an executemany the first record has a null, the second has a value.
         if not dumpers:
-            dumpers = self._row_dumpers = [None] * len(params)
+            change_state = True
+            dumpers = [None] * len(params)
+            types = [INVALID_OID] * len(params)
+            pqformats = [pq.Format.TEXT] * len(params)
 
         for i in range(len(params)):
             param = params[i]
@@ -144,11 +162,24 @@ class Transformer(AdaptContext):
                 dumper = dumpers[i]
                 if not dumper:
                     dumper = dumpers[i] = self.get_dumper(param, formats[i])
-                ps[i] = dumper.dump(param)
-                ts[i] = dumper.oid
-                fs[i] = dumper.format
-
-        return ps, tuple(ts), fs
+                    change_state = True
+                    if not types:
+                        types = list(self.types)
+                    types[i] = dumper.oid
+                    if not pqformats:
+                        pqformats = list(self.formats)
+                    pqformats[i] = dumper.format
+
+                out[i] = dumper.dump(param)
+
+        if change_state:
+            self._row_dumpers = dumpers
+            assert types is not None
+            self.types = tuple(types)
+            assert pqformats is not None
+            self.formats = pqformats
+
+        return out
 
     def get_dumper(self, obj: Any, format: PyFormat) -> "Dumper":
         """
index fc86dc441ac6ce6ecc0d0a1d94ef187d67766906..d6c164998ecf2a07e9a699fea0f8edd5f040e40d 100644 (file)
@@ -169,6 +169,10 @@ class Loader(Protocol):
 
 
 class Transformer(Protocol):
+
+    types: Tuple[int, ...]
+    formats: Sequence[pq.Format]
+
     def __init__(self, context: Optional[AdaptContext] = None):
         ...
 
@@ -205,7 +209,7 @@ class Transformer(Protocol):
 
     def dump_sequence(
         self, params: Sequence[Any], formats: Sequence[PyFormat]
-    ) -> Tuple[List[Any], Tuple[int, ...], Sequence[pq.Format]]:
+    ) -> Sequence[Optional[Buffer]]:
         ...
 
     def get_dumper(self, obj: Any, format: PyFormat) -> Dumper:
index 61338f6ce38e61a31756ebd42276dd3eca7e1f98..3e15e15477c4c46c3debb0fccb2f8ea93a636bba 100644 (file)
@@ -562,7 +562,7 @@ def _format_row_binary(
         out = bytearray()
 
     out += _pack_int2(len(row))
-    adapted, _, _ = tx.dump_sequence(row, [PyFormat.BINARY] * len(row))
+    adapted = tx.dump_sequence(row, [PyFormat.BINARY] * len(row))
     for b in adapted:
         if b is not None:
             out += _pack_int4(len(b))
index 89d793c91a22208c5e61987aaf0f3c6cf5158061..50eea4d8a51077162cdd58bada45c65ecb91c110 100644 (file)
@@ -80,11 +80,11 @@ class TupleBinaryDumper(RecursiveDumper):
         super().__init__(cls, context)
         nfields = len(self.info.field_types)
         self._tx.set_dumper_types(self.info.field_types, self.format)
-        self._formats = [PyFormat.from_pq(self.format)] * nfields
+        self._formats = (PyFormat.from_pq(self.format),) * nfields
 
     def dump(self, obj: Tuple[Any, ...]) -> bytearray:
         out = bytearray(pack_len(len(obj)))
-        adapted, _, _ = self._tx.dump_sequence(obj, self._formats)
+        adapted = self._tx.dump_sequence(obj, self._formats)
         for i in range(len(obj)):
             b = adapted[i]
             oid = self.info.field_types[i]
index a6e07e88d7e65efdc0c7bfca864d1eab454c08f5..ebcb6addeda61bf917fb6604f8c280c9f73ce0b1 100644 (file)
@@ -17,6 +17,8 @@ from psycopg.pq.abc import PGconn, PGresult
 from psycopg.connection import BaseConnection
 
 class Transformer(abc.AdaptContext):
+    types: Tuple[int, ...]
+    formats: Sequence[pq.Format]
     def __init__(self, context: Optional[abc.AdaptContext] = None): ...
     @property
     def connection(self) -> Optional[BaseConnection[Any]]: ...
@@ -39,7 +41,7 @@ class Transformer(abc.AdaptContext):
     ) -> None: ...
     def dump_sequence(
         self, params: Sequence[Any], formats: Sequence[PyFormat]
-    ) -> Tuple[List[Any], Tuple[int, ...], Sequence[pq.Format]]: ...
+    ) -> Sequence[Optional[abc.Buffer]]: ...
     def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper: ...
     def get_dumper_by_oid(self, oid: int, format: pq.Format) -> abc.Dumper: ...
     def load_rows(
index 91b5587564249ab3d25b6b7fa7c1aeede1e19306..500251d5e7b8eb40780d9abf62e76c1d01956182 100644 (file)
@@ -488,7 +488,7 @@ def test_copy_query(conn):
     cur = conn.cursor()
     with cur.copy("copy (select 1) to stdout") as copy:
         assert cur._query.query == b"copy (select 1) to stdout"
-        assert cur._query.params is None
+        assert not cur._query.params
         list(copy)
 
 
index 3fbe7941da4b88b0e65a8367c31b4e22fd9e6ab2..40ac2cc04eeb3049623150998129f621fdd6e61d 100644 (file)
@@ -478,7 +478,7 @@ async def test_copy_query(aconn):
     cur = aconn.cursor()
     async with cur.copy("copy (select 1) to stdout") as copy:
         assert cur._query.query == b"copy (select 1) to stdout"
-        assert cur._query.params is None
+        assert not cur._query.params
         async for record in copy:
             pass
 
index 686b20902aabf034da2b2f5889a45735512d3d57..4445d1898b6427a7b921d8134633f7bdbc49264d 100644 (file)
@@ -403,7 +403,7 @@ def test_query_params_execute(conn):
 
     cur.execute("select 1")
     assert cur._query.query == b"select 1"
-    assert cur._query.params is None
+    assert not cur._query.params
 
     with pytest.raises(psycopg.DataError):
         cur.execute("select %t::int", ["wat"])
index 4f51334563b74876b8194f269ef8630885521563..c561e9fab45d3356b677c85e22b9825816e0762a 100644 (file)
@@ -405,7 +405,7 @@ async def test_query_params_execute(aconn):
 
     await cur.execute("select 1")
     assert cur._query.query == b"select 1"
-    assert cur._query.params is None
+    assert not cur._query.params
 
     with pytest.raises(psycopg.DataError):
         await cur.execute("select %t::int", ["wat"])
index 0314dd8cf64f179b03b458d8cd6a66031cfce4a6..9765ea9d3170a6aff31fe214b03c1c04296ead67 100644 (file)
@@ -68,7 +68,7 @@ def test_split_query_bad(input):
 @pytest.mark.parametrize(
     "query, params, want, wformats, wparams",
     [
-        (b"", None, b"", None, None),
+        (b"", None, b"", (), ()),
         (b"", [], b"", [], []),
         (b"%%", [], b"%", [], []),
         (b"select %t", (1,), b"select $1", [pq.Format.TEXT], [b"1"]),