From: Daniele Varrazzo
Date: Sat, 9 Jan 2021 03:45:08 +0000 (+0100)
Subject: Added reading row-by-row from Copy
X-Git-Tag: 3.0.dev0~191
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ff462961ea4980e0fe5fb88d1aafa85ab9c9b1e7;p=thirdparty%2Fpsycopg.git

Added reading row-by-row from Copy
---

diff --git a/docs/copy.rst b/docs/copy.rst
index dcbe18722..95f8551d0 100644
--- a/docs/copy.rst
+++ b/docs/copy.rst
@@ -1,3 +1,5 @@
+.. currentmodule:: psycopg3
+
 .. index::
     pair: COPY; SQL command
 
@@ -12,41 +14,128 @@ it, with some SQL creativity).
 
 .. __: https://www.postgresql.org/docs/current/sql-copy.html
 
-Using `!psycopg3` you can do three things:
+Copy is supported using the `Cursor.copy()` method, passing it a query of the
+form :sql:`COPY ... FROM STDIN` or :sql:`COPY ... TO STDOUT`, and managing the
+resulting `Copy` object in a ``with`` block:
 
-- loading data into the database row-by-row, from a stream of Python objects;
-- loading data into the database block-by-block, with data already formatted in
-  a way suitable for :sql:`COPY FROM`;
-- reading data from the database block-by-block, with data emitted by a
-  :sql:`COPY TO` statement.
+.. code:: python
 
-The missing quadrant, copying data from the database row-by-row, is not
-covered by COPY because that's pretty much normal querying, and :sql:`COPY TO`
-doesn't offer enough metadata to decode the data to Python objects.
+    with cursor.copy("COPY table_name (col1, col2) FROM STDIN") as copy:
+        # pass data to the 'copy' object using write()/write_row()
 
-The first option is the most powerful, because it allows to load data into the
-database from any Python iterable (a list of tuple, or any iterable of
-sequences): the Python values are adapted as they would be in normal querying.
-To perform such operation use a :sql:`COPY [table] FROM STDIN` with
-`Cursor.copy()` and use `~Copy.write_row()` on the resulting object in a
-``with`` block. On exiting the block the operation will be concluded:
+You can dynamically compose a COPY statement by using objects from the
+`psycopg3.sql` module:
 
 .. code:: python
 
-    with cursor.copy("COPY table_name (col1, col2) FROM STDIN") as copy:
-        for row in source:
-            copy.write_row(row)
+    with cursor.copy(
+        sql.SQL("COPY {} TO STDOUT").format(sql.Identifier("table_name"))
+    ) as copy:
+        # read data from the 'copy' object using read()/read_row()
+
+The connection is subject to the usual transaction behaviour, so, unless the
+connection is in autocommit, at the end of the COPY operation you will still
+have to commit the pending changes, or roll them back. See
+:ref:`transactions` for details.
+
+
+Writing data row-by-row
+-----------------------
+
+Using a copy operation you can load data into the database from any Python
+iterable (a list of tuples, or any iterable of sequences): the Python values
+are adapted as they would be in normal querying. To perform such an operation,
+use a :sql:`COPY ... FROM STDIN` with `Cursor.copy()` and use
+`~Copy.write_row()` on the resulting object in a ``with`` block. On exiting
+the block the operation will be concluded:
+
+.. code:: python
+
+    records = [(10, 20, "hello"), (40, None, "world")]
+
+    with cursor.copy("COPY sample (col1, col2, col3) FROM STDIN") as copy:
+        for record in records:
+            copy.write_row(record)
 
 If an exception is raised inside the block, the operation is interrupted and
-the records inserted so far discarded.
+the records inserted so far are discarded.
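+
+For instance, in this hypothetical sketch (the `!sample` table and the
+deliberate error are assumed only for illustration), the exception interrupts
+the COPY, so no record ends up in the table:
+
+.. code:: python
+
+    try:
+        with cursor.copy("COPY sample (col1, col2, col3) FROM STDIN") as copy:
+            copy.write_row((10, 20, "hello"))
+            raise ValueError("nope")
+    except ValueError:
+        pass  # the COPY was aborted: the row (10, 20, "hello") was discarded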
+
+In order to read or write from `!Copy` row-by-row you must not specify
+:sql:`COPY` options such as :sql:`FORMAT CSV`, :sql:`DELIMITER`, :sql:`NULL`:
+please leave these details alone, thank you :)
+
+Binary copy is supported by specifying :sql:`FORMAT BINARY` in the :sql:`COPY`
+statement. In order to load binary data, all the types passed to the database
+must have a binary dumper registered (see :ref:`binary-data`).
+
+Note that PostgreSQL is particularly finicky when loading data in binary mode
+and will apply *no cast rule*. This means that e.g. passing a Python `!int`
+object to an :sql:`integer` column (aka :sql:`int4`) will likely fail, because
+the default `!int` `~adapt.Dumper` will use the :sql:`bigint` aka :sql:`int8`
+format. You can work around the problem by registering the right binary dumper
+on the cursor or using the right data wrapper (see :ref:`adaptation`).
+
+
+Reading data row-by-row
+-----------------------
+
+You can also do the opposite, reading rows out of a :sql:`COPY ... TO STDOUT`
+operation, by iterating on `~Copy.rows()`. However, this is not something you
+will normally want to do: the regular query process is usually easier to
+use.
+
+PostgreSQL, currently, doesn't give complete type information on :sql:`COPY
+TO`, so the rows returned will have unparsed data, as strings or bytes,
+according to the format.
+
+.. code:: python
+
+    with cur.copy("COPY (VALUES (10::int, current_date)) TO STDOUT") as copy:
+        for row in copy.rows():
+            print(row)  # returns unparsed data: ('10', '2046-12-24')
+
+You can improve the results by using `~Copy.set_types()` before reading, but
+you have to specify the types yourself.
+
+.. code:: python
+
+    from psycopg3.oids import builtins
+
+    with cur.copy("COPY (VALUES (10::int, current_date)) TO STDOUT") as copy:
+        copy.set_types([builtins["int4"].oid, builtins["date"].oid])
+        for row in copy.rows():
+            print(row)  # (10, datetime.date(2046, 12, 24))
+
+.. admonition:: TODO
+
+    Document the `!builtins` register... but more likely do something
+    better such as allowing to pass type names, unifying `TypeRegistry` and
+    `AdaptContext`, none of which I have documented, so you haven't seen
+    anything... 👀
+
+
+Copying block-by-block
+----------------------
 
 If data is already formatted in a way suitable for copy (for instance because
 it is coming from a file resulting from a previous `COPY TO` operation) it can
-be loaded using `Copy.write()` instead.
+be loaded into the database using `Copy.write()` instead.
+
+.. code:: python
 
-In order to read data in :sql:`COPY` format you can use a :sql:`COPY TO
+    with open("data", "r") as f:
+        with cursor.copy("COPY data FROM STDIN") as copy:
+            while data := f.read(BLOCK_SIZE):
+                copy.write(data)
+
+In this case you can use any :sql:`COPY` option and format, as long as the
+input data is compatible. Data can be passed as `!str`, if the copy is in
+:sql:`FORMAT TEXT`, or as `!bytes`, which works with both :sql:`FORMAT TEXT`
+and :sql:`FORMAT BINARY`.
+
+In order to produce data in :sql:`COPY` format you can use a :sql:`COPY ... TO
 STDOUT` statement and iterate over the resulting `Copy` object, which will
-produce `!bytes`:
+produce a stream of `!bytes`:
 
 .. code:: python
 
@@ -55,16 +144,20 @@ produce `!bytes`:
     for data in copy:
         f.write(data)
 
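+For instance, a block-by-block copy can stream a table from one database to
+another. A hypothetical sketch, assuming `!conn1` and `!conn2` are connections
+to the source and target database, and that `!src` and `!tgt` are tables with
+compatible schemas (note that :sql:`FORMAT BINARY` normally requires the
+column types on the two sides to match exactly):
+
+.. code:: python
+
+    with conn1.cursor().copy("COPY src TO STDOUT (FORMAT BINARY)") as copy1:
+        with conn2.cursor().copy("COPY tgt FROM STDIN (FORMAT BINARY)") as copy2:
+            for data in copy1:
+                copy2.write(data)
+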
-Asynchronous operations are supported using the same patterns on an
-`AsyncConnection`. For instance, if `!f` is an object supporting an
-asynchronous `!read()` method returning :sql:`COPY` data, a fully-async copy
-operation could be:
+
+Asynchronous copy support
+-------------------------
+
+Asynchronous operations are supported using the same patterns as above, using
+the objects obtained by an `AsyncConnection`. For instance, if `!f` is an
+object supporting an asynchronous `!read()` method returning :sql:`COPY` data,
+a fully-async copy operation could be:
 
 .. code:: python
 
     async with cursor.copy("COPY data FROM STDIN") as copy:
-        while data := await f.read()
+        while data := await f.read():
             await copy.write(data)
 
-Binary data can be produced and consumed using :sql:`FORMAT BINARY` in the
-:sql:`COPY` command: see :ref:`binary-data` for details and limitations.
+The `AsyncCopy` object documentation describes the signature of the
+asynchronous methods and the differences from its sync `Copy` counterpart.
diff --git a/docs/cursor.rst b/docs/cursor.rst
index 3dce83733..6f5368dc3 100644
--- a/docs/cursor.rst
+++ b/docs/cursor.rst
@@ -187,28 +187,37 @@ Cursor support objects
 
         See :ref:`copy` for details.
 
+    .. automethod:: write_row
+
+        The data in the tuple will be converted as configured on the cursor;
+        see :ref:`adaptation` for details.
+
+    .. automethod:: write
     .. automethod:: read
 
         Instead of using `!read()` you can even iterate on the object to read
        its data row by row, using ``for row in copy: ...``.
 
-    .. automethod:: write
-    .. automethod:: write_row
-
-        The data in the tuple will be converted as configured on the cursor;
-        see :ref:`adaptation` for details.
+    .. automethod:: rows
+    .. automethod:: read_row
+    .. automethod:: set_types
 
 .. autoclass:: AsyncCopy()
 
-    The object is normally returned by ``async with`` `AsyncCursor.copy()`. Its methods are
-    the same of the `Copy` object but offering an `asyncio` interface
-    (`await`, `async for`, `async with`).
+    The object is normally returned by ``async with`` `AsyncCursor.copy()`.
+    Its methods are similar to those of the `Copy` object, but offer an
+    `asyncio` interface (`await`, `async for`, `async with`).
 
+    .. automethod:: write_row
+    .. automethod:: write
     .. automethod:: read
 
        Instead of using `!read()` you can even iterate on the object to read
        its data row by row, using ``async for row in copy: ...``.
 
-    .. automethod:: write
-    .. automethod:: write_row
+    .. automethod:: rows
+
+        Use it as `async for record in copy.rows():` ...
+
+    .. automethod:: read_row
diff --git a/psycopg3/psycopg3/_transform.py b/psycopg3/psycopg3/_transform.py
index 1f792f563..cbc6ad00e 100644
--- a/psycopg3/psycopg3/_transform.py
+++ b/psycopg3/psycopg3/_transform.py
@@ -172,6 +172,12 @@ class Transformer(AdaptContext):
     def load_sequence(
         self, record: Sequence[Optional[bytes]]
     ) -> Tuple[Any, ...]:
+        if len(self._row_loaders) != len(record):
+            raise e.ProgrammingError(
+                f"cannot load sequence of {len(record)} items:"
+                f" {len(self._row_loaders)} loaders registered"
+            )
+
         return tuple(
             (self._row_loaders[i](val) if val is not None else None)
             for i, val in enumerate(record)
diff --git a/psycopg3/psycopg3/copy.py b/psycopg3/psycopg3/copy.py
index 33481fd9a..0d91961d9 100644
--- a/psycopg3/psycopg3/copy.py
+++ b/psycopg3/psycopg3/copy.py
@@ -7,8 +7,8 @@ psycopg3 copy support
 import re
 import struct
 from types import TracebackType
-from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic
-from typing import Any, Dict, Match, Optional, Sequence, Type, Union
+from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union
+from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple
 from typing_extensions import Protocol
 
 from . import pq
@@ -57,14 +57,25 @@ class BaseCopy(Generic[ConnectionType]):
         self._format_row: FormatFunc
         if self.format == Format.TEXT:
             self._format_row = format_row_text
+            self._parse_row = parse_row_text
         else:
             self._format_row = format_row_binary
+            self._parse_row = parse_row_binary
 
     def __repr__(self) -> str:
         cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
         info = pq.misc.connection_summary(self._pgconn)
         return f"<{cls} {info} at 0x{id(self):x}>"
 
+    def set_types(self, types: Sequence[int]) -> None:
+        """
+        Set the types expected out of a :sql:`COPY TO` operation.
+
+        Without setting the types, the data from :sql:`COPY TO` will be
+        returned as unparsed strings or bytes.
+        """
+        self.transformer.set_row_types(types, [self.format] * len(types))
+
     # High level copy protocol generators (state change of the Copy object)
 
     def _read_gen(self) -> PQGen[bytes]:
@@ -81,6 +92,19 @@ class BaseCopy(Generic[ConnectionType]):
             self.cursor._rowcount = nrows if nrows is not None else -1
         return b""
 
+    def _read_row_gen(self) -> PQGen[Optional[Tuple[Any, ...]]]:
+        data = yield from self._read_gen()
+        if not data:
+            return None
+        if self.format == Format.BINARY:
+            if not self._signature_sent:
+                assert data.startswith(_binary_signature)
+                self._signature_sent = True
+                data = data[len(_binary_signature) :]
+            elif data == _binary_trailer:
+                return None
+        return self._parse_row(data, self.transformer)
+
     def _write_gen(self, buffer: Union[str, bytes]) -> PQGen[None]:
         # if write() was called, assume the header was sent together with the
         # first block of data.
@@ -175,22 +199,48 @@ class Copy(BaseCopy["Connection"]):
     __module__ = "psycopg3"
 
     def read(self) -> bytes:
-        """Read a row of data after a :sql:`COPY TO` operation.
+        """
+        Read an unparsed row from a table after a :sql:`COPY TO` operation.
 
         Return an empty bytes string when the data is finished.
         """
         return self.connection.wait(self._read_gen())
 
+    def rows(self) -> Iterator[Tuple[Any, ...]]:
+        """
+        Iterate on the result of a :sql:`COPY TO` operation record by record.
+
+        Note that the records returned will be tuples of unparsed strings or
+        bytes, unless data types are specified using `set_types()`.
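+
+        A sketch of typical use (`!process()` is a hypothetical placeholder
+        for your own code)::
+
+            copy.set_types([builtins["int4"].oid, builtins["text"].oid])
+            for record in copy.rows():
+                process(record)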
+ """ + while True: + record = self.read_row() + if record is None: + break + yield record + + def read_row(self) -> Optional[Tuple[Any, ...]]: + """ + Read a parsed row of data from a table after a :sql:`COPY TO` operation. + + Return `!None` when the data is finished. + + Note that the records returned will be tuples of unparsed strings or + bytes, unless data types are specified using `set_types()`. + """ + return self.connection.wait(self._read_row_gen()) + def write(self, buffer: Union[str, bytes]) -> None: - """Write a block of data after a :sql:`COPY FROM` operation. + """ + Write a block of data to a table after a :sql:`COPY FROM` operation. - If the COPY is in binary format *buffer* must be `!bytes`. In text mode - it can be either `!bytes` or `!str`. + If the :sql:`COPY` is in binary format *buffer* must be `!bytes`. In + text mode it can be either `!bytes` or `!str`. """ self.connection.wait(self._write_gen(buffer)) def write_row(self, row: Sequence[Any]) -> None: - """Write a record after a :sql:`COPY FROM` operation.""" + """Write a record to a table after a :sql:`COPY FROM` operation.""" self.connection.wait(self._write_row_gen(row)) def _finish(self, error: str = "") -> None: @@ -225,6 +275,16 @@ class AsyncCopy(BaseCopy["AsyncConnection"]): async def read(self) -> bytes: return await self.connection.wait(self._read_gen()) + async def rows(self) -> AsyncIterator[Tuple[Any, ...]]: + while True: + record = await self.read_row() + if record is None: + break + yield record + + async def read_row(self) -> Optional[Tuple[Any, ...]]: + return await self.connection.wait(self._read_row_gen()) + async def write(self, buffer: Union[str, bytes]) -> None: await self.connection.wait(self._write_gen(buffer)) @@ -269,7 +329,7 @@ def _format_row_text( if item is not None: dumper = tx.get_dumper(item, Format.TEXT) b = dumper.dump(item) - out += _bsrepl_re.sub(_bsrepl_sub, b) + out += _dump_re.sub(_dump_sub, b) else: out += br"\N" out += b"\t" @@ -298,8 +358,33 @@ def _format_row_binary( return out +def parse_row_text(data: bytes, tx: Transformer) -> Tuple[Any, ...]: + fields = data.split(b"\t") + fields[-1] = fields[-1][:-1] # drop \n + row = [None if f == b"\\N" else _load_re.sub(_load_sub, f) for f in fields] + return tx.load_sequence(row) + + +def parse_row_binary(data: bytes, tx: Transformer) -> Tuple[Any, ...]: + row: List[Optional[bytes]] = [] + nfields = _unpack_int2(data, 0)[0] + pos = 2 + for i in range(nfields): + length = _unpack_int4(data, pos)[0] + pos += 4 + if length >= 0: + row.append(data[pos : pos + length]) + pos += length + else: + row.append(None) + + return tx.load_sequence(row) + + _pack_int2 = struct.Struct("!h").pack _pack_int4 = struct.Struct("!i").pack +_unpack_int2 = struct.Struct("!h").unpack_from +_unpack_int4 = struct.Struct("!i").unpack_from _binary_signature = ( # Signature, flags, extra length @@ -310,23 +395,32 @@ _binary_signature = ( _binary_trailer = b"\xff\xff" _binary_null = b"\xff\xff\xff\xff" +_dump_re = re.compile(b"[\b\t\n\v\f\r\\\\]") +_dump_repl = { + b"\b": b"\\b", + b"\t": b"\\t", + b"\n": b"\\n", + b"\v": b"\\v", + b"\f": b"\\f", + b"\r": b"\\r", + b"\\": b"\\\\", +} + -def _bsrepl_sub( - m: Match[bytes], - __map: Dict[bytes, bytes] = { - b"\b": b"\\b", - b"\t": b"\\t", - b"\n": b"\\n", - b"\v": b"\\v", - b"\f": b"\\f", - b"\r": b"\\r", - b"\\": b"\\\\", - }, +def _dump_sub( + m: Match[bytes], __map: Dict[bytes, bytes] = _dump_repl ) -> bytes: return __map[m.group(0)] -_bsrepl_re = re.compile(b"[\b\t\n\v\f\r\\\\]") +_load_re = 
+_load_repl = {v: k for k, v in _dump_repl.items()}
+
+
+def _load_sub(
+    m: Match[bytes], __map: Dict[bytes, bytes] = _load_repl
+) -> bytes:
+    return __map[m.group(0)]
 
 
 # Override it with fast object if available
diff --git a/psycopg3_c/psycopg3_c/_psycopg3/transform.pyx b/psycopg3_c/psycopg3_c/_psycopg3/transform.pyx
index f7e5d5e36..b2fecf325 100644
--- a/psycopg3_c/psycopg3_c/_psycopg3/transform.pyx
+++ b/psycopg3_c/psycopg3_c/_psycopg3/transform.pyx
@@ -10,7 +10,8 @@ too many temporary Python objects and performing less memory copying.
 
 from cpython.ref cimport Py_INCREF
 from cpython.dict cimport PyDict_GetItem, PyDict_SetItem
-from cpython.list cimport PyList_New, PyList_GET_ITEM, PyList_SET_ITEM
+from cpython.list cimport (
+    PyList_New, PyList_GET_ITEM, PyList_SET_ITEM, PyList_GET_SIZE)
 from cpython.tuple cimport PyTuple_New, PyTuple_SET_ITEM
 from cpython.object cimport PyObject, PyObject_CallFunctionObjArgs
 
@@ -324,22 +325,38 @@ cdef class Transformer:
 
         return record
 
     def load_sequence(self, record: Sequence[Optional[bytes]]) -> Tuple[Any, ...]:
-        cdef int length = len(record)
-        rv = PyTuple_New(length)
-        cdef RowLoader loader
+        cdef int nfields = len(record)
+        out = PyTuple_New(nfields)
+        cdef PyObject *loader  # borrowed RowLoader
+        cdef int col
+        cdef char *ptr
+        cdef Py_ssize_t size
 
-        cdef int i
-        for i in range(length):
-            item = record[i]
+        row_loaders = self._row_loaders  # avoid an incref/decref per item
+        if PyList_GET_SIZE(row_loaders) != nfields:
+            raise e.ProgrammingError(
+                f"cannot load sequence of {nfields} items:"
+                f" {len(self._row_loaders)} loaders registered")
+
+        for col in range(nfields):
+            item = record[col]
             if item is None:
-                pyval = None
+                Py_INCREF(None)
+                PyTuple_SET_ITEM(out, col, None)
+                continue
+
+            loader = PyList_GET_ITEM(row_loaders, col)
+            if (<RowLoader>loader).cloader is not None:
+                _buffer_as_string_and_size(item, &ptr, &size)
+                pyval = (<RowLoader>loader).cloader.cload(ptr, size)
             else:
-                loader = self._row_loaders[i]
-                pyval = loader.pyloader(item)
+                pyval = PyObject_CallFunctionObjArgs(
+                    (<RowLoader>loader).pyloader, <PyObject *>item, NULL)
+
             Py_INCREF(pyval)
-            PyTuple_SET_ITEM(rv, i, pyval)
+            PyTuple_SET_ITEM(out, col, pyval)
 
-        return rv
+        return out
 
     def get_loader(self, oid: int, format: Format) -> "Loader":
         return self._c_get_loader(oid, format)
diff --git a/tests/test_copy.py b/tests/test_copy.py
index fccf994f3..31498b74c 100644
--- a/tests/test_copy.py
+++ b/tests/test_copy.py
@@ -6,7 +6,9 @@ from itertools import cycle
 import pytest
 
 from psycopg3 import pq
+from psycopg3 import sql
 from psycopg3 import errors as e
+from psycopg3.oids import builtins
 from psycopg3.adapt import Format
 from psycopg3.types.numeric import Int4
 
@@ -74,6 +76,107 @@ def test_copy_out_iter(conn, format):
     assert list(copy) == want
 
 
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_read_rows(conn, format):
+    cur = conn.cursor()
+    with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        # TODO: should be passed by name
+        # big refactoring to be had, to have builtins not global and merged
+        # to adaptation context I guess...
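+        # (builtins["int4"] is the catalog entry of the int4 type; its .oid
+        # attribute is the PostgreSQL type oid that set_types() expects)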
+        copy.set_types(
+            [builtins["int4"].oid, builtins["int4"].oid, builtins["text"].oid]
+        )
+        rows = []
+        while 1:
+            row = copy.read_row()
+            if not row:
+                break
+            rows.append(row)
+    assert rows == sample_records
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_rows(conn, format):
+    cur = conn.cursor()
+    with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        copy.set_types(
+            [builtins["int4"].oid, builtins["int4"].oid, builtins["text"].oid]
+        )
+        rows = list(copy.rows())
+    assert rows == sample_records
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_copy_out_allchars(conn, format):
+    cur = conn.cursor()
+    chars = list(map(chr, range(1, 256))) + [eur]
+    conn.client_encoding = "utf8"
+    rows = []
+    query = sql.SQL(
+        "copy (select unnest({}::text[])) to stdout (format {})"
+    ).format(chars, sql.SQL(format.name))
+    with cur.copy(query) as copy:
+        copy.set_types([builtins["text"].oid])
+        while 1:
+            row = copy.read_row()
+            if not row:
+                break
+            assert len(row) == 1
+            rows.append(row[0])
+
+    assert rows == chars
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_read_row_notypes(conn, format):
+    cur = conn.cursor()
+    with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        rows = []
+        while 1:
+            row = copy.read_row()
+            if not row:
+                break
+            rows.append(row)
+
+    ref = [
+        tuple(py_to_raw(i, format) for i in record)
+        for record in sample_records
+    ]
+    assert rows == ref
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_rows_notypes(conn, format):
+    cur = conn.cursor()
+    with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        rows = list(copy.rows())
+    ref = [
+        tuple(py_to_raw(i, format) for i in record)
+        for record in sample_records
+    ]
+    assert rows == ref
+
+
+@pytest.mark.parametrize("err", [-1, 1])
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_copy_out_badntypes(conn, format, err):
+    cur = conn.cursor()
+    with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        copy.set_types([0] * (len(sample_records[0]) + err))
+        with pytest.raises(e.ProgrammingError):
+            copy.read_row()
+
+
 @pytest.mark.parametrize(
     "format, buffer",
     [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
@@ -366,6 +469,19 @@ def test_str(conn):
     assert "[INTRANS]" in str(copy)
 
 
+def py_to_raw(item, fmt):
+    """Convert from Python type to the expected result from the db"""
+    if fmt == Format.TEXT:
+        if isinstance(item, int):
+            return str(item)
+    else:
+        if isinstance(item, int):
+            return bytes([0, 0, 0, item])
+        elif isinstance(item, str):
+            return item.encode("utf8")
+    return item
+
+
 def ensure_table(cur, tabledef, name="copy_in"):
     cur.execute(f"drop table if exists {name}")
     cur.execute(f"create table {name} ({tabledef})")
diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py
index b1ca5ce13..f8999c719 100644
--- a/tests/test_copy_async.py
+++ b/tests/test_copy_async.py
@@ -6,11 +6,14 @@ from itertools import cycle
 import pytest
 
 from psycopg3 import pq
+from psycopg3 import sql
 from psycopg3 import errors as e
+from psycopg3.oids import builtins
 from psycopg3.adapt import Format
 
 from .test_copy import sample_text, sample_binary, sample_binary_rows  # noqa
 from .test_copy import eur, sample_values, sample_records, sample_tabledef
+from .test_copy import py_to_raw
 
 pytestmark = pytest.mark.asyncio
 
@@ -53,6 +56,111 @@ async def test_copy_out_iter(aconn, format):
     assert got == want
 
 
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_read_rows(aconn, format):
+    cur = await aconn.cursor()
+    async with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        # TODO: should be passed by name
+        # big refactoring to be had, to have builtins not global and merged
+        # to adaptation context I guess...
+        copy.set_types(
+            [builtins["int4"].oid, builtins["int4"].oid, builtins["text"].oid]
+        )
+        rows = []
+        while 1:
+            row = await copy.read_row()
+            if not row:
+                break
+            rows.append(row)
+    assert rows == sample_records
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_rows(aconn, format):
+    cur = await aconn.cursor()
+    async with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        copy.set_types(
+            [builtins["int4"].oid, builtins["int4"].oid, builtins["text"].oid]
+        )
+        rows = []
+        async for row in copy.rows():
+            rows.append(row)
+    assert rows == sample_records
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_out_allchars(aconn, format):
+    cur = await aconn.cursor()
+    chars = list(map(chr, range(1, 256))) + [eur]
+    await aconn.set_client_encoding("utf8")
+    rows = []
+    query = sql.SQL(
+        "copy (select unnest({}::text[])) to stdout (format {})"
+    ).format(chars, sql.SQL(format.name))
+    async with cur.copy(query) as copy:
+        copy.set_types([builtins["text"].oid])
+        while 1:
+            row = await copy.read_row()
+            if not row:
+                break
+            assert len(row) == 1
+            rows.append(row[0])
+
+    assert rows == chars
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_read_row_notypes(aconn, format):
+    cur = await aconn.cursor()
+    async with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        rows = []
+        while 1:
+            row = await copy.read_row()
+            if not row:
+                break
+            rows.append(row)
+
+    ref = [
+        tuple(py_to_raw(i, format) for i in record)
+        for record in sample_records
+    ]
+    assert rows == ref
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_rows_notypes(aconn, format):
+    cur = await aconn.cursor()
+    async with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        rows = []
+        async for row in copy.rows():
+            rows.append(row)
+    ref = [
+        tuple(py_to_raw(i, format) for i in record)
+        for record in sample_records
+    ]
+    assert rows == ref
+
+
+@pytest.mark.parametrize("err", [-1, 1])
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_out_badntypes(aconn, format, err):
+    cur = await aconn.cursor()
+    async with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        copy.set_types([0] * (len(sample_records[0]) + err))
+        with pytest.raises(e.ProgrammingError):
+            await copy.read_row()
+
+
 @pytest.mark.parametrize(
     "format, buffer",
     [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],