]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added reading row-by-row from Copy
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 9 Jan 2021 03:45:08 +0000 (04:45 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 9 Jan 2021 16:23:57 +0000 (17:23 +0100)
docs/copy.rst
docs/cursor.rst
psycopg3/psycopg3/_transform.py
psycopg3/psycopg3/copy.py
psycopg3_c/psycopg3_c/_psycopg3/transform.pyx
tests/test_copy.py
tests/test_copy_async.py

index dcbe18722dc67f2fd74c2dc5055117e9b810811f..95f8551d07b4880926358220a96a11c9140b212f 100644 (file)
@@ -1,3 +1,5 @@
+.. currentmodule:: psycopg3
+
 .. index::
     pair: COPY; SQL command
 
@@ -12,41 +14,128 @@ it, with some SQL creativity).
 
 .. __: https://www.postgresql.org/docs/current/sql-copy.html
 
-Using `!psycopg3` you can do three things:
+Copy is supported using the `Cursor.copy()` method, passing it a query of the
+form :sql:`COPY ... FROM STDIN` or :sql:`COPY ... TO STDOUT`, and managing the
+resulting `Copy` object in a ``with`` block:
 
-- loading data into the database row-by-row, from a stream of Python objects;
-- loading data into the database block-by-block, with data already formatted in
-  a way suitable for :sql:`COPY FROM`;
-- reading data from the database block-by-block, with data emitted by a
-  :sql:`COPY TO` statement.
+.. code:: python
 
-The missing quadrant, copying data from the database row-by-row, is not
-covered by COPY because that's pretty much normal querying, and :sql:`COPY TO`
-doesn't offer enough metadata to decode the data to Python objects.
+    with cursor.copy("COPY table_name (col1, col2) FROM STDIN") as copy:
+        # pass data to the 'copy' object using write()/write_row()
 
-The first option is the most powerful, because it allows to load data into the
-database from any Python iterable (a list of tuple, or any iterable of
-sequences): the Python values are adapted as they would be in normal querying.
-To perform such operation use a :sql:`COPY [table] FROM STDIN` with
-`Cursor.copy()` and use `~Copy.write_row()` on the resulting object in a
-``with`` block. On exiting the block the operation will be concluded:
+You can compose a dynamically a COPY statement by using objects from the
+`psycopg3.sql` module:
 
 .. code:: python
 
-    with cursor.copy("COPY table_name (col1, col2) FROM STDIN") as copy:
-        for row in source:
-            copy.write_row(row)
+    with cursor.copy(
+        sql.SQL("COPY {} TO STDOUT").format(sql.Identifier("table_name"))
+    ) as copy:
+        # read data from the 'copy' object using read()/read_row()
+
+The connection is subject to the usual transaction behaviour, so, unless the
+connection is in autocommit, at the end of the COPY operation you will still
+have to commit the pending changes and you can still roll them back. See
+:ref:`transactions` for details.
+
+
+Writing data row-by-row
+-----------------------
+
+Using a copy operation you can load data into the database from any Python
+iterable (a list of tuple, or any iterable of sequences): the Python values
+are adapted as they would be in normal querying. To perform such operation use
+a :sql:`COPY ... FROM STDIN` with `Cursor.copy()` and use `~Copy.write_row()`
+on the resulting object in a ``with`` block. On exiting the block the
+operation will be concluded:
+
+.. code:: python
+
+    records = [(10, 20, "hello"), (40, None, "world")]
+
+    with cursor.copy("COPY sample (col1, col2, col3) FROM STDIN") as copy:
+        for record in records:
+            copy.write_row(record)
 
 If an exception is raised inside the block, the operation is interrupted and
-the records inserted so far discarded.
+the records inserted so far are discarded.
+
+In order to read or write from `!Copy` row-by-row you must not specify
+:sql:`COPY` options such as :sql:`FORMAT CSV`, :sql:`DELIMITER`, :sql:`NULL`:
+please leave these details alone, thank you :)
+
+Binary copy is supported by specifying :sql:`FORMAT BINARY` in the :sql:`COPY`
+statement. In order to load binary data, all the types passed to the database
+must have a binary dumper registered (see see :ref:`binary-data`).
+
+Note that PostgreSQL is particularly finicky when loading data in binary mode
+and will apply *no cast rule*. This means that e.g. passing a Python `!int`
+object to an :sql:`integer` column (aka :sql:`int4`) will likely fail, because
+the default `!int` `~adapt.Dumper` will use the :sql:`bigint` aka :sql:`int8`
+format. You can work around the problem by registering the right binary dumper
+on the cursor or using the right data wrapper (see :ref:`adaptation`).
+
+
+Reading data row-by-row
+-----------------------
+
+You can also do the opposite, reading rows out of a :sql:`COPY ... TO STDOUT`
+operation, by iterating on `~Copy.rows()`. However this is not something you
+may want to do normally: usually the normal query process will be easier to
+use.
+
+PostgreSQL, currently, doesn't give complete type information on :sql:`COPY
+TO`, so the rows returned will have unparsed data, as strings or bytes,
+according to the format.
+
+.. code:: python
+
+    with cur.copy("COPY (VALUES (10::int, current_date)) TO STDOUT") as copy:
+        for row in copy.rows():
+            print(row)  # return unparsed data: ('10', '2046-12-24')
+
+You can improve the results by using `~Copy.set_types()` before reading, but
+you have to specify them yourselves.
+
+.. code:: python
+
+    from psycopg3.oids import builtins
+
+    with cur.copy("COPY (VALUES (10::int, current_date)) TO STDOUT") as copy:
+        copy.set_types([builtins["int4"].oid, builtins["date"].oid])
+        for row in copy.rows():
+            print(row)  # (10, datetime.date(2046, 12, 24))
+
+.. admonition:: TODO
+
+    Document the `!builtins` register... but more likely do something
+    better such as allowing to pass type names, unifying `TypeRegistry` and
+    `AdaptContext`, none of which I have documented, so you haven't seen
+    anything... ðŸ‘€
+
+
+Copying block-by-block
+----------------------
 
 If data is already formatted in a way suitable for copy (for instance because
 it is coming from a file resulting from a previous `COPY TO` operation) it can
-be loaded using `Copy.write()` instead.
+be loaded into the database using `Copy.write()` instead.
+
+.. code:: python
 
-In order to read data in :sql:`COPY` format you can use a :sql:`COPY TO
+    with open("data", "r") as f:
+        with cursor.copy("COPY data FROM STDIN") as copy:
+            while data := f.read(BLOCK_SIZE):
+                copy.write(data)
+
+In this case you can use any :sql:`COPY` option and format, as long as the
+input data is compatible. Data can be passed as `!str`, if the copy is in
+:sql:`FORMAT TEXT`, or as `!bytes`, which works with both :sql:`FORMAT TEXT`
+and :sql:`FORMAT BINARY`.
+
+In order to produce data in :sql:`COPY` format you can use a :sql:`COPY ... TO
 STDOUT` statement and iterate over the resulting `Copy` object, which will
-produce `!bytes`:
+produce a stream of `!bytes`:
 
 .. code:: python
 
@@ -55,16 +144,20 @@ produce `!bytes`:
             for data in copy:
                 f.write(data)
 
-Asynchronous operations are supported using the same patterns on an
-`AsyncConnection`. For instance, if `!f` is an object supporting an
-asynchronous `!read()` method returning :sql:`COPY` data, a fully-async copy
-operation could be:
+
+Asynchronous copy support
+-------------------------
+
+Asynchronous operations are supported using the same patterns as above, using
+the objects obtained by an `AsyncConnection`. For instance, if `!f` is an
+object supporting an asynchronous `!read()` method returning :sql:`COPY` data,
+a fully-async copy operation could be:
 
 .. code:: python
 
     async with cursor.copy("COPY data FROM STDIN") as copy:
-        while data := await f.read()
+        while data := await f.read():
             await copy.write(data)
 
-Binary data can be produced and consumed using :sql:`FORMAT BINARY` in the
-:sql:`COPY` command: see :ref:`binary-data` for details and limitations.
+The `AsyncCopy` object documentation describe the signature of the
+asynchronous methods and the differences from its sync `Copy` counterpart.
index 3dce837337bd386838fbe195c9ea32ac4e47b844..6f5368dc3d0c4ec1b646c488410b44f7bc6da20d 100644 (file)
@@ -187,28 +187,37 @@ Cursor support objects
 
     See :ref:`copy` for details.
 
+    .. automethod:: write_row
+
+        The data in the tuple will be converted as configured on the cursor;
+        see :ref:`adaptation` for details.
+
+    .. automethod:: write
     .. automethod:: read
 
         Instead of using `!read()` you can even iterate on the object to read
         its data row by row, using ``for row in copy: ...``.
 
-    .. automethod:: write
-    .. automethod:: write_row
-
-        The data in the tuple will be converted as configured on the cursor;
-        see :ref:`adaptation` for details.
+    .. automethod:: rows
+    .. automethod:: read_row
+    .. automethod:: set_types
 
 
 .. autoclass:: AsyncCopy()
 
-    The object is normally returned by ``async with`` `AsyncCursor.copy()`. Its methods are
-    the same of the `Copy` object but offering an `asyncio` interface
-    (`await`, `async for`, `async with`).
+    The object is normally returned by ``async with`` `AsyncCursor.copy()`.
+    Its methods are similar to the ones of the `Copy` object but offering an
+    `asyncio` interface (`await`, `async for`, `async with`).
 
+    .. automethod:: write_row
+    .. automethod:: write
     .. automethod:: read
 
         Instead of using `!read()` you can even iterate on the object to read
         its data row by row, using ``async for row in copy: ...``.
 
-    .. automethod:: write
-    .. automethod:: write_row
+    .. automethod:: rows
+
+        Use it as `async for record in copy.rows():` ...
+
+    .. automethod:: read_row
index 1f792f5634e2333f2369186e3ff7d61eed44274c..cbc6ad00e340350c1c48e3da1f2341d8912957c5 100644 (file)
@@ -172,6 +172,12 @@ class Transformer(AdaptContext):
     def load_sequence(
         self, record: Sequence[Optional[bytes]]
     ) -> Tuple[Any, ...]:
+        if len(self._row_loaders) != len(record):
+            raise e.ProgrammingError(
+                f"cannot load sequence of {len(record)} items:"
+                f" {len(self._row_loaders)} loaders registered"
+            )
+
         return tuple(
             (self._row_loaders[i](val) if val is not None else None)
             for i, val in enumerate(record)
index 33481fd9a0b1dcb43537d7e6ffea62ec671c0b05..0d91961d9aebafc3686972680d1a42948896bc36 100644 (file)
@@ -7,8 +7,8 @@ psycopg3 copy support
 import re
 import struct
 from types import TracebackType
-from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic
-from typing import Any, Dict, Match, Optional, Sequence, Type, Union
+from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union
+from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple
 from typing_extensions import Protocol
 
 from . import pq
@@ -57,14 +57,25 @@ class BaseCopy(Generic[ConnectionType]):
         self._format_row: FormatFunc
         if self.format == Format.TEXT:
             self._format_row = format_row_text
+            self._parse_row = parse_row_text
         else:
             self._format_row = format_row_binary
+            self._parse_row = parse_row_binary
 
     def __repr__(self) -> str:
         cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
         info = pq.misc.connection_summary(self._pgconn)
         return f"<{cls} {info} at 0x{id(self):x}>"
 
+    def set_types(self, types: Sequence[int]) -> None:
+        """
+        Set the types expected out of a :sql:`COPY TO` operation.
+
+        Without setting the types, the data from :sql:`COPY TO` will be
+        returned as unparsed strings or bytes.
+        """
+        self.transformer.set_row_types(types, [self.format] * len(types))
+
     # High level copy protocol generators (state change of the Copy object)
 
     def _read_gen(self) -> PQGen[bytes]:
@@ -81,6 +92,19 @@ class BaseCopy(Generic[ConnectionType]):
         self.cursor._rowcount = nrows if nrows is not None else -1
         return b""
 
+    def _read_row_gen(self) -> PQGen[Optional[Tuple[Any, ...]]]:
+        data = yield from self._read_gen()
+        if not data:
+            return None
+        if self.format == Format.BINARY:
+            if not self._signature_sent:
+                assert data.startswith(_binary_signature)
+                self._signature_sent = True
+                data = data[len(_binary_signature) :]
+            elif data == _binary_trailer:
+                return None
+        return self._parse_row(data, self.transformer)
+
     def _write_gen(self, buffer: Union[str, bytes]) -> PQGen[None]:
         # if write() was called, assume the header was sent together with the
         # first block of data.
@@ -175,22 +199,48 @@ class Copy(BaseCopy["Connection"]):
     __module__ = "psycopg3"
 
     def read(self) -> bytes:
-        """Read a row of data after a :sql:`COPY TO` operation.
+        """
+        Read an unparsed row from a table after a :sql:`COPY TO` operation.
 
         Return an empty bytes string when the data is finished.
         """
         return self.connection.wait(self._read_gen())
 
+    def rows(self) -> Iterator[Tuple[Any, ...]]:
+        """
+        Iterate on the result of a :sql:`COPY TO` operation record by record.
+
+        Note that the records returned will be tuples of of unparsed strings or
+        bytes, unless data types are specified using `set_types()`.
+        """
+        while True:
+            record = self.read_row()
+            if record is None:
+                break
+            yield record
+
+    def read_row(self) -> Optional[Tuple[Any, ...]]:
+        """
+        Read a parsed row of data from a table after a :sql:`COPY TO` operation.
+
+        Return `!None` when the data is finished.
+
+        Note that the records returned will be tuples of unparsed strings or
+        bytes, unless data types are specified using `set_types()`.
+        """
+        return self.connection.wait(self._read_row_gen())
+
     def write(self, buffer: Union[str, bytes]) -> None:
-        """Write a block of data after a :sql:`COPY FROM` operation.
+        """
+        Write a block of data to a table after a :sql:`COPY FROM` operation.
 
-        If the COPY is in binary format *buffer* must be `!bytes`. In text mode
-        it can be either `!bytes` or `!str`.
+        If the :sql:`COPY` is in binary format *buffer* must be `!bytes`. In
+        text mode it can be either `!bytes` or `!str`.
         """
         self.connection.wait(self._write_gen(buffer))
 
     def write_row(self, row: Sequence[Any]) -> None:
-        """Write a record after a :sql:`COPY FROM` operation."""
+        """Write a record to a table after a :sql:`COPY FROM` operation."""
         self.connection.wait(self._write_row_gen(row))
 
     def _finish(self, error: str = "") -> None:
@@ -225,6 +275,16 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
     async def read(self) -> bytes:
         return await self.connection.wait(self._read_gen())
 
+    async def rows(self) -> AsyncIterator[Tuple[Any, ...]]:
+        while True:
+            record = await self.read_row()
+            if record is None:
+                break
+            yield record
+
+    async def read_row(self) -> Optional[Tuple[Any, ...]]:
+        return await self.connection.wait(self._read_row_gen())
+
     async def write(self, buffer: Union[str, bytes]) -> None:
         await self.connection.wait(self._write_gen(buffer))
 
@@ -269,7 +329,7 @@ def _format_row_text(
         if item is not None:
             dumper = tx.get_dumper(item, Format.TEXT)
             b = dumper.dump(item)
-            out += _bsrepl_re.sub(_bsrepl_sub, b)
+            out += _dump_re.sub(_dump_sub, b)
         else:
             out += br"\N"
         out += b"\t"
@@ -298,8 +358,33 @@ def _format_row_binary(
     return out
 
 
+def parse_row_text(data: bytes, tx: Transformer) -> Tuple[Any, ...]:
+    fields = data.split(b"\t")
+    fields[-1] = fields[-1][:-1]  # drop \n
+    row = [None if f == b"\\N" else _load_re.sub(_load_sub, f) for f in fields]
+    return tx.load_sequence(row)
+
+
+def parse_row_binary(data: bytes, tx: Transformer) -> Tuple[Any, ...]:
+    row: List[Optional[bytes]] = []
+    nfields = _unpack_int2(data, 0)[0]
+    pos = 2
+    for i in range(nfields):
+        length = _unpack_int4(data, pos)[0]
+        pos += 4
+        if length >= 0:
+            row.append(data[pos : pos + length])
+            pos += length
+        else:
+            row.append(None)
+
+    return tx.load_sequence(row)
+
+
 _pack_int2 = struct.Struct("!h").pack
 _pack_int4 = struct.Struct("!i").pack
+_unpack_int2 = struct.Struct("!h").unpack_from
+_unpack_int4 = struct.Struct("!i").unpack_from
 
 _binary_signature = (
     # Signature, flags, extra length
@@ -310,23 +395,32 @@ _binary_signature = (
 _binary_trailer = b"\xff\xff"
 _binary_null = b"\xff\xff\xff\xff"
 
+_dump_re = re.compile(b"[\b\t\n\v\f\r\\\\]")
+_dump_repl = {
+    b"\b": b"\\b",
+    b"\t": b"\\t",
+    b"\n": b"\\n",
+    b"\v": b"\\v",
+    b"\f": b"\\f",
+    b"\r": b"\\r",
+    b"\\": b"\\\\",
+}
+
 
-def _bsrepl_sub(
-    m: Match[bytes],
-    __map: Dict[bytes, bytes] = {
-        b"\b": b"\\b",
-        b"\t": b"\\t",
-        b"\n": b"\\n",
-        b"\v": b"\\v",
-        b"\f": b"\\f",
-        b"\r": b"\\r",
-        b"\\": b"\\\\",
-    },
+def _dump_sub(
+    m: Match[bytes], __map: Dict[bytes, bytes] = _dump_repl
 ) -> bytes:
     return __map[m.group(0)]
 
 
-_bsrepl_re = re.compile(b"[\b\t\n\v\f\r\\\\]")
+_load_re = re.compile(b"\\\\[btnvfr\\\\]")
+_load_repl = {v: k for k, v in _dump_repl.items()}
+
+
+def _load_sub(
+    m: Match[bytes], __map: Dict[bytes, bytes] = _load_repl
+) -> bytes:
+    return __map[m.group(0)]
 
 
 # Override it with fast object if available
index f7e5d5e36c98ac55941aa20bc007f3537c36d46a..b2fecf3253204a2c9f6c9e7b08b3af91aad46522 100644 (file)
@@ -10,7 +10,8 @@ too many temporary Python objects and performing less memory copying.
 
 from cpython.ref cimport Py_INCREF
 from cpython.dict cimport PyDict_GetItem, PyDict_SetItem
-from cpython.list cimport PyList_New, PyList_GET_ITEM, PyList_SET_ITEM
+from cpython.list cimport (
+    PyList_New, PyList_GET_ITEM, PyList_SET_ITEM, PyList_GET_SIZE)
 from cpython.tuple cimport PyTuple_New, PyTuple_SET_ITEM
 from cpython.object cimport PyObject, PyObject_CallFunctionObjArgs
 
@@ -324,22 +325,38 @@ cdef class Transformer:
         return record
 
     def load_sequence(self, record: Sequence[Optional[bytes]]) -> Tuple[Any, ...]:
-        cdef int length = len(record)
-        rv = PyTuple_New(length)
-        cdef RowLoader loader
+        cdef int nfields = len(record)
+        out = PyTuple_New(nfields)
+        cdef PyObject *loader  # borrowed RowLoader
+        cdef int col
+        cdef char *ptr
+        cdef Py_ssize_t size
 
-        cdef int i
-        for i in range(length):
-            item = record[i]
+        row_loaders = self._row_loaders  # avoid an incref/decref per item
+        if PyList_GET_SIZE(row_loaders) != nfields:
+            raise e.ProgrammingError(
+                f"cannot load sequence of {nfields} items:"
+                f" {len(self._row_loaders)} loaders registered")
+
+        for col in range(nfields):
+            item = record[col]
             if item is None:
-                pyval = None
+                Py_INCREF(None)
+                PyTuple_SET_ITEM(out, col, None)
+                continue
+
+            loader = PyList_GET_ITEM(row_loaders, col)
+            if (<RowLoader>loader).cloader is not None:
+                _buffer_as_string_and_size(item, &ptr, &size)
+                pyval = (<RowLoader>loader).cloader.cload(ptr, size)
             else:
-                loader = self._row_loaders[i]
-                pyval = loader.pyloader(item)
+                pyval = PyObject_CallFunctionObjArgs(
+                    (<RowLoader>loader).pyloader, <PyObject *>item, NULL)
+
             Py_INCREF(pyval)
-            PyTuple_SET_ITEM(rv, i, pyval)
+            PyTuple_SET_ITEM(out, col, pyval)
 
-        return rv
+        return out
 
     def get_loader(self, oid: int, format: Format) -> "Loader":
         return self._c_get_loader(<PyObject *>oid, <PyObject *>format)
index fccf994f32736ccf1b27599f3e8890372c2a6721..31498b74cbd46224bba8ee0ba4edd56c10dae7f4 100644 (file)
@@ -6,7 +6,9 @@ from itertools import cycle
 import pytest
 
 from psycopg3 import pq
+from psycopg3 import sql
 from psycopg3 import errors as e
+from psycopg3.oids import builtins
 from psycopg3.adapt import Format
 from psycopg3.types.numeric import Int4
 
@@ -74,6 +76,107 @@ def test_copy_out_iter(conn, format):
         assert list(copy) == want
 
 
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_read_rows(conn, format):
+    cur = conn.cursor()
+    with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        # TODO: should be passed by name
+        # big refactoring to be had, to have builtins not global and merged
+        # to adaptation context I guess...
+        copy.set_types(
+            [builtins["int4"].oid, builtins["int4"].oid, builtins["text"].oid]
+        )
+        rows = []
+        while 1:
+            row = copy.read_row()
+            if not row:
+                break
+            rows.append(row)
+    assert rows == sample_records
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_rows(conn, format):
+    cur = conn.cursor()
+    with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        copy.set_types(
+            [builtins["int4"].oid, builtins["int4"].oid, builtins["text"].oid]
+        )
+        rows = list(copy.rows())
+    assert rows == sample_records
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_copy_out_allchars(conn, format):
+    cur = conn.cursor()
+    chars = list(map(chr, range(1, 256))) + [eur]
+    conn.client_encoding = "utf8"
+    rows = []
+    query = sql.SQL(
+        "copy (select unnest({}::text[])) to stdout (format {})"
+    ).format(chars, sql.SQL(format.name))
+    with cur.copy(query) as copy:
+        copy.set_types([builtins["text"].oid])
+        while 1:
+            row = copy.read_row()
+            if not row:
+                break
+            assert len(row) == 1
+            rows.append(row[0])
+
+    assert rows == chars
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_read_row_notypes(conn, format):
+    cur = conn.cursor()
+    with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        rows = []
+        while 1:
+            row = copy.read_row()
+            if not row:
+                break
+            rows.append(row)
+
+    ref = [
+        tuple(py_to_raw(i, format) for i in record)
+        for record in sample_records
+    ]
+    assert rows == ref
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_rows_notypes(conn, format):
+    cur = conn.cursor()
+    with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        rows = list(copy.rows())
+    ref = [
+        tuple(py_to_raw(i, format) for i in record)
+        for record in sample_records
+    ]
+    assert rows == ref
+
+
+@pytest.mark.parametrize("err", [-1, 1])
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_copy_out_badntypes(conn, format, err):
+    cur = conn.cursor()
+    with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        copy.set_types([0] * (len(sample_records[0]) + err))
+        with pytest.raises(e.ProgrammingError):
+            copy.read_row()
+
+
 @pytest.mark.parametrize(
     "format, buffer",
     [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
@@ -366,6 +469,19 @@ def test_str(conn):
     assert "[INTRANS]" in str(copy)
 
 
+def py_to_raw(item, fmt):
+    """Convert from Python type to the expected result from the db"""
+    if fmt == Format.TEXT:
+        if isinstance(item, int):
+            return str(item)
+    else:
+        if isinstance(item, int):
+            return bytes([0, 0, 0, item])
+        elif isinstance(item, str):
+            return item.encode("utf8")
+    return item
+
+
 def ensure_table(cur, tabledef, name="copy_in"):
     cur.execute(f"drop table if exists {name}")
     cur.execute(f"create table {name} ({tabledef})")
index b1ca5ce136a6f2c10dbef3e222bc06b2d88042ba..f8999c719a909534097e0e4c5911c79502540e87 100644 (file)
@@ -6,11 +6,14 @@ from itertools import cycle
 import pytest
 
 from psycopg3 import pq
+from psycopg3 import sql
 from psycopg3 import errors as e
+from psycopg3.oids import builtins
 from psycopg3.adapt import Format
 
 from .test_copy import sample_text, sample_binary, sample_binary_rows  # noqa
 from .test_copy import eur, sample_values, sample_records, sample_tabledef
+from .test_copy import py_to_raw
 
 pytestmark = pytest.mark.asyncio
 
@@ -53,6 +56,111 @@ async def test_copy_out_iter(aconn, format):
     assert got == want
 
 
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_read_rows(aconn, format):
+    cur = await aconn.cursor()
+    async with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        # TODO: should be passed by name
+        # big refactoring to be had, to have builtins not global and merged
+        # to adaptation context I guess...
+        copy.set_types(
+            [builtins["int4"].oid, builtins["int4"].oid, builtins["text"].oid]
+        )
+        rows = []
+        while 1:
+            row = await copy.read_row()
+            if not row:
+                break
+            rows.append(row)
+    assert rows == sample_records
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_rows(aconn, format):
+    cur = await aconn.cursor()
+    async with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        copy.set_types(
+            [builtins["int4"].oid, builtins["int4"].oid, builtins["text"].oid]
+        )
+        rows = []
+        async for row in copy.rows():
+            rows.append(row)
+    assert rows == sample_records
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_out_allchars(aconn, format):
+    cur = await aconn.cursor()
+    chars = list(map(chr, range(1, 256))) + [eur]
+    await aconn.set_client_encoding("utf8")
+    rows = []
+    query = sql.SQL(
+        "copy (select unnest({}::text[])) to stdout (format {})"
+    ).format(chars, sql.SQL(format.name))
+    async with cur.copy(query) as copy:
+        copy.set_types([builtins["text"].oid])
+        while 1:
+            row = await copy.read_row()
+            if not row:
+                break
+            assert len(row) == 1
+            rows.append(row[0])
+
+    assert rows == chars
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_read_row_notypes(aconn, format):
+    cur = await aconn.cursor()
+    async with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        rows = []
+        while 1:
+            row = await copy.read_row()
+            if not row:
+                break
+            rows.append(row)
+
+    ref = [
+        tuple(py_to_raw(i, format) for i in record)
+        for record in sample_records
+    ]
+    assert rows == ref
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_rows_notypes(aconn, format):
+    cur = await aconn.cursor()
+    async with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        rows = []
+        async for row in copy.rows():
+            rows.append(row)
+    ref = [
+        tuple(py_to_raw(i, format) for i in record)
+        for record in sample_records
+    ]
+    assert rows == ref
+
+
+@pytest.mark.parametrize("err", [-1, 1])
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_out_badntypes(aconn, format, err):
+    cur = await aconn.cursor()
+    async with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        copy.set_types([0] * (len(sample_records[0]) + err))
+        with pytest.raises(e.ProgrammingError):
+            await copy.read_row()
+
+
 @pytest.mark.parametrize(
     "format, buffer",
     [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],