git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Replace set_row_types() with set_dumper_types/set_loader_types()
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 27 Aug 2021 12:50:41 +0000 (14:50 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 27 Aug 2021 22:23:47 +0000 (00:23 +0200)
dump_sequence() can make use of the dumpers pre-set by set_dumper_types()
and is now used in composite binary dump and in binary COPY FROM.

The interface should be iterated on further, because for the latter use case
the extra info (oids, formats) is just a waste of resources.

Only the Python implementation has been changed so far, the C
implementation will be changed down the line.

psycopg/psycopg/_transform.py
psycopg/psycopg/abc.py
psycopg/psycopg/copy.py
psycopg/psycopg/types/composite.py
psycopg_c/psycopg_c/_psycopg.pyi
psycopg_c/psycopg_c/_psycopg/transform.pyx
tests/test_copy.py
tests/test_copy_async.py

index c4dbbaf52343745fce2c21b2b78327bb50018934..97c74d6e41f89f054d883e5e4777e62b1aaf4b3d 100644 (file)
@@ -109,14 +109,23 @@ class Transformer(AdaptContext):
                 fmt = result.fformat(i) if format is None else format
                 rc.append(self.get_loader(oid, fmt).load)  # type: ignore
 
-    def set_row_types(
-        self, types: Sequence[int], formats: Sequence[pq.Format]
+    def set_dumper_types(
+        self, types: Sequence[int], format: pq.Format
     ) -> None:
-        rc: List[LoadFunc] = []
+        dumpers: List[Optional["Dumper"]] = []
         for i in range(len(types)):
-            rc.append(self.get_loader(types[i], formats[i]).load)
+            dumpers.append(self.get_dumper_by_oid(types[i], format))
 
-        self._row_loaders = rc
+        self._row_dumpers = dumpers
+
+    def set_loader_types(
+        self, types: Sequence[int], format: pq.Format
+    ) -> None:
+        loaders: List[LoadFunc] = []
+        for i in range(len(types)):
+            loaders.append(self.get_loader(types[i], format).load)
+
+        self._row_loaders = loaders
 
     def dump_sequence(
         self, params: Sequence[Any], formats: Sequence[PyFormat]
index 584b2ae037d3dd918c227ce0e1ff5c76e6e0170f..fc86dc441ac6ce6ecc0d0a1d94ef187d67766906 100644 (file)
@@ -193,8 +193,13 @@ class Transformer(Protocol):
     ) -> None:
         ...
 
-    def set_row_types(
-        self, types: Sequence[int], formats: Sequence[pq.Format]
+    def set_dumper_types(
+        self, types: Sequence[int], format: pq.Format
+    ) -> None:
+        ...
+
+    def set_loader_types(
+        self, types: Sequence[int], format: pq.Format
     ) -> None:
         ...
 
index 0e889dacb7a560dcd5bb0046cca11f24cc116b08..61338f6ce38e61a31756ebd42276dd3eca7e1f98 100644 (file)
@@ -96,9 +96,15 @@ class BaseCopy(Generic[ConnectionType]):
         oids = [
             t if isinstance(t, int) else registry.get_oid(t) for t in types
         ]
-        self.formatter.transformer.set_row_types(
-            oids, [self.formatter.format] * len(types)
-        )
+
+        if self._pgresult.status == ExecStatus.COPY_IN:
+            self.formatter.transformer.set_dumper_types(
+                oids, self.formatter.format
+            )
+        else:
+            self.formatter.transformer.set_loader_types(
+                oids, self.formatter.format
+            )
 
     # High level copy protocol generators (state change of the Copy object)
 
@@ -556,10 +562,9 @@ def _format_row_binary(
         out = bytearray()
 
     out += _pack_int2(len(row))
-    for item in row:
-        if item is not None:
-            dumper = tx.get_dumper(item, PyFormat.BINARY)
-            b = dumper.dump(item)
+    adapted, _, _ = tx.dump_sequence(row, [PyFormat.BINARY] * len(row))
+    for b in adapted:
+        if b is not None:
             out += _pack_int4(len(b))
             out += b
         else:
index 950142d76471498c9324b091d9b38bb56e59cf1c..89d793c91a22208c5e61987aaf0f3c6cf5158061 100644 (file)
@@ -11,7 +11,6 @@ from typing import Any, Callable, cast, Iterator, List, Optional
 from typing import Sequence, Tuple, Type
 
 from .. import pq
-from .. import errors as e
 from .. import postgres
 from ..abc import AdaptContext, Buffer
 from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader
@@ -77,22 +76,19 @@ class TupleBinaryDumper(RecursiveDumper):
     # Subclasses must set an info
     info: CompositeInfo
 
-    def dump(self, obj: Tuple[Any, ...]) -> bytearray:
-
-        if len(obj) != len(self.info.field_types):
-            raise e.DataError(
-                f"expected a sequence of {len(self.info.field_types)} items,"
-                f" got {len(obj)}"
-            )
+    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+        super().__init__(cls, context)
+        nfields = len(self.info.field_types)
+        self._tx.set_dumper_types(self.info.field_types, self.format)
+        self._formats = [PyFormat.from_pq(self.format)] * nfields
 
+    def dump(self, obj: Tuple[Any, ...]) -> bytearray:
         out = bytearray(pack_len(len(obj)))
-        get_dumper = self._tx.get_dumper_by_oid
+        adapted, _, _ = self._tx.dump_sequence(obj, self._formats)
         for i in range(len(obj)):
-            item = obj[i]
+            b = adapted[i]
             oid = self.info.field_types[i]
-            if item is not None:
-                dumper = get_dumper(oid, self.format)
-                b = dumper.dump(item)
+            if b is not None:
                 out += _pack_oidlen(oid, len(b))
                 out += b
             else:
@@ -178,7 +174,7 @@ class RecordBinaryLoader(RecursiveLoader):
 
     def _config_types(self, data: bytes) -> None:
         oids = [r[0] for r in self._walk_record(data)]
-        self._tx.set_row_types(oids, [pq.Format.BINARY] * len(oids))
+        self._tx.set_loader_types(oids, self.format)
 
 
 class CompositeLoader(RecordLoader):
@@ -201,9 +197,7 @@ class CompositeLoader(RecordLoader):
         )
 
     def _config_types(self, data: bytes) -> None:
-        self._tx.set_row_types(
-            self.fields_types, [pq.Format.TEXT] * len(self.fields_types)
-        )
+        self._tx.set_loader_types(self.fields_types, self.format)
 
 
 class CompositeBinaryLoader(RecordBinaryLoader):
index 5c9fcef6853cc26ad1acbe7e09149efdc45807f7..a6e07e88d7e65efdc0c7bfca864d1eab454c08f5 100644 (file)
@@ -31,8 +31,11 @@ class Transformer(abc.AdaptContext):
         set_loaders: bool = True,
         format: Optional[pq.Format] = None,
     ) -> None: ...
-    def set_row_types(
-        self, types: Sequence[int], formats: Sequence[pq.Format]
+    def set_dumper_types(
+        self, types: Sequence[int], format: pq.Format
+    ) -> None: ...
+    def set_loader_types(
+        self, types: Sequence[int], format: pq.Format
     ) -> None: ...
     def dump_sequence(
         self, params: Sequence[Any], formats: Sequence[PyFormat]
index a68886fc96cb5cd3fb5c7ebe3d267b904bc4771e..ab1fa453c186062e742c370b594995e7eaceb117 100644 (file)
@@ -125,36 +125,31 @@ cdef class Transformer:
         cdef int i
         cdef object tmp
         cdef list types
-        cdef list formats
+        if format is None:
+            format = libpq.PQfformat(res, 0)
+
         if set_loaders:
             types = PyList_New(self._nfields)
-            formats = PyList_New(self._nfields)
             for i in range(self._nfields):
                 tmp = libpq.PQftype(res, i)
                 Py_INCREF(tmp)
                 PyList_SET_ITEM(types, i, tmp)
 
-                tmp = libpq.PQfformat(res, i) if format is None else format
-                Py_INCREF(tmp)
-                PyList_SET_ITEM(formats, i, tmp)
-
-            self._c_set_row_types(self._nfields, types, formats)
+            self._c_loader_types(self._nfields, types, format)
 
-    def set_row_types(self,
-            types: Sequence[int], formats: Sequence[Format]) -> None:
-        self._c_set_row_types(len(types), types, formats)
+    def set_loader_types(self,
+            types: Sequence[int], format: Format) -> None:
+        self._c_loader_types(len(types), types, format)
 
-    cdef void _c_set_row_types(self, Py_ssize_t ntypes, list types, list formats):
+    cdef void _c_loader_types(self, Py_ssize_t ntypes, list types, int format):
         cdef list loaders = PyList_New(ntypes)
 
         # these are used more as Python object than C
         cdef PyObject *oid
-        cdef PyObject *fmt
         cdef PyObject *row_loader
         for i in range(ntypes):
             oid = PyList_GET_ITEM(types, i)
-            fmt = PyList_GET_ITEM(formats, i)
-            row_loader = self._c_get_loader(oid, fmt)
+            row_loader = self._c_get_loader(oid, <PyObject *>format)
             Py_INCREF(<object>row_loader)
             PyList_SET_ITEM(loaders, i, <object>row_loader)
 
index 6994ab4b1b18f81afb166528dabb79c682624e2c..91b5587564249ab3d25b6b7fa7c1aeede1e19306 100644 (file)
@@ -20,7 +20,7 @@ from .utils import gc_collect
 
 eur = "\u20ac"
 
-sample_records = [(Int4(10), Int4(20), "hello"), (Int4(40), None, "world")]
+sample_records = [(10, 20, "hello"), (40, None, "world")]
 
 sample_values = "values (10::int, 20::int, 'hello'::text), (40, NULL, 'world')"
 
@@ -79,6 +79,7 @@ def test_copy_out_iter(conn, format):
         want = [row + b"\n" for row in sample_text.splitlines()]
     else:
         want = sample_binary_rows
+
     cur = conn.cursor()
     with cur.copy(
         f"copy ({sample_values}) to stdout (format {format.name})"
@@ -341,6 +342,22 @@ def test_copy_in_records(conn, format):
     ensure_table(cur, sample_tabledef)
 
     with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+        for row in sample_records:
+            if format == Format.BINARY:
+                row = tuple(Int4(i) if isinstance(i, int) else i for i in row)
+            copy.write_row(row)
+
+    data = cur.execute("select * from copy_in order by 1").fetchall()
+    assert data == sample_records
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_copy_in_records_set_types(conn, format):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+
+    with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+        copy.set_types(["int4", "int4", "text"])
         for row in sample_records:
             copy.write_row(row)
 
index b118f64630ef151749fc1e96e752a462225d35d7..3fbe7941da4b88b0e65a8367c31b4e22fd9e6ab2 100644 (file)
@@ -14,6 +14,7 @@ from psycopg.pq import Format
 from psycopg.types import TypeInfo
 from psycopg.adapt import PyFormat as PgFormat
 from psycopg.types.hstore import register_hstore
+from psycopg.types.numeric import Int4
 
 from .utils import gc_collect
 from .test_copy import sample_text, sample_binary, sample_binary_rows  # noqa
@@ -325,6 +326,25 @@ async def test_copy_in_records(aconn, format):
     async with cur.copy(
         f"copy copy_in from stdin (format {format.name})"
     ) as copy:
+        for row in sample_records:
+            if format == Format.BINARY:
+                row = tuple(Int4(i) if isinstance(i, int) else i for i in row)
+            await copy.write_row(row)
+
+    await cur.execute("select * from copy_in order by 1")
+    data = await cur.fetchall()
+    assert data == sample_records
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_in_records_set_types(aconn, format):
+    cur = aconn.cursor()
+    await ensure_table(cur, sample_tabledef)
+
+    async with cur.copy(
+        f"copy copy_in from stdin (format {format.name})"
+    ) as copy:
+        copy.set_types(["int4", "int4", "text"])
         for row in sample_records:
             await copy.write_row(row)