From 9c312a5b670c4d97575b040bfe9e791841884305 Mon Sep 17 00:00:00 2001
From: Daniele Varrazzo
Date: Fri, 20 Nov 2020 02:47:44 +0000
Subject: [PATCH] COPY operations update cursor.rowcount

Copy.read() changed to return b"" on EOF, consistent with file.read().
Also changed the copy generators to return the final result of the
operation, and pgconn.get_copy_data() to always return bytes as the
second argument, because it can never return an empty string except in
case of error.

With this changeset all the psycopg2 test_copy tests pass, both sync
and async.
---
 psycopg3/psycopg3/copy.py           |  58 +++++++----
 psycopg3/psycopg3/cursor.py         |   8 +-
 psycopg3/psycopg3/generators.py     |  45 ++++-----
 psycopg3/psycopg3/pq/pq_ctypes.py   |   4 +-
 psycopg3/psycopg3/pq/proto.py       |   2 +-
 psycopg3_c/psycopg3_c/pq_cython.pyx |   4 +-
 tests/pq/test_copy.py               |   2 +-
 tests/test_copy.py                  | 149 +++++++++++++++++++++++++++-
 tests/test_copy_async.py            | 149 +++++++++++++++++++++++++++-
 9 files changed, 358 insertions(+), 63 deletions(-)

diff --git a/psycopg3/psycopg3/copy.py b/psycopg3/psycopg3/copy.py
index 494d77196..5c7fd5161 100644
--- a/psycopg3/psycopg3/copy.py
+++ b/psycopg3/psycopg3/copy.py
@@ -11,18 +11,20 @@ from typing import Any, Dict, List, Match, Optional, Sequence, Type, Union
 from types import TracebackType
 
 from .pq import Format, ExecStatus
-from .proto import ConnectionType, Transformer
+from .proto import ConnectionType
 from .generators import copy_from, copy_to, copy_end
 
 if TYPE_CHECKING:
     from .pq.proto import PGresult
+    from .cursor import BaseCursor  # noqa: F401
     from .connection import Connection, AsyncConnection  # noqa: F401
 
 
 class BaseCopy(Generic[ConnectionType]):
-    def __init__(self, connection: ConnectionType, transformer: Transformer):
-        self.connection = connection
-        self.transformer = transformer
+    def __init__(self, cursor: "BaseCursor[ConnectionType]"):
+        self.cursor = cursor
+        self.connection = cursor.connection
+        self.transformer = cursor._transformer
 
         assert (
             self.transformer.pgresult
@@ -125,20 +127,24 @@ _bsrepl_re = re.compile(b"[\b\t\n\v\f\r\\\\]")
 class Copy(BaseCopy["Connection"]):
     """Manage a :sql:`COPY` operation."""
 
-    def read(self) -> Optional[bytes]:
+    def read(self) -> bytes:
         """Read a row after a :sql:`COPY TO` operation.
 
-        Return `None` when the data is finished.
+        Return an empty bytes string when the data is finished.
         """
         if self._finished:
-            return None
+            return b""
 
         conn = self.connection
-        rv = conn.wait(copy_from(conn.pgconn))
-        if rv is None:
-            self._finished = True
+        res = conn.wait(copy_from(conn.pgconn))
+        if isinstance(res, bytes):
+            return res
 
-        return rv
+        # res is the final PGresult
+        self._finished = True
+        nrows = res.command_tuples
+        self.cursor._rowcount = nrows if nrows is not None else -1
+        return b""
 
     def write(self, buffer: Union[str, bytes]) -> None:
         """Write a block of data after a :sql:`COPY FROM` operation."""
@@ -154,10 +160,13 @@ class Copy(BaseCopy["Connection"]):
         """Terminate a :sql:`COPY FROM` operation."""
         conn = self.connection
         berr = error.encode(conn.client_encoding, "replace") if error else None
-        conn.wait(copy_end(conn.pgconn, berr))
+        res = conn.wait(copy_end(conn.pgconn, berr))
+        nrows = res.command_tuples
+        self.cursor._rowcount = nrows if nrows is not None else -1
         self._finished = True
 
     def __enter__(self) -> "Copy":
+        assert not self._finished
         return self
 
     def __exit__(
@@ -183,7 +192,7 @@ class Copy(BaseCopy["Connection"]):
     def __iter__(self) -> Iterator[bytes]:
         while True:
             data = self.read()
-            if data is None:
+            if not data:
                 break
             yield data
 
@@ -191,16 +200,20 @@ class Copy(BaseCopy["Connection"]):
 class AsyncCopy(BaseCopy["AsyncConnection"]):
     """Manage an asynchronous :sql:`COPY` operation."""
 
-    async def read(self) -> Optional[bytes]:
+    async def read(self) -> bytes:
         if self._finished:
-            return None
+            return b""
 
         conn = self.connection
-        rv = await conn.wait(copy_from(conn.pgconn))
-        if rv is None:
-            self._finished = True
+        res = await conn.wait(copy_from(conn.pgconn))
+        if isinstance(res, bytes):
+            return res
 
-        return rv
+        # res is the final PGresult
+        self._finished = True
+        nrows = res.command_tuples
+        self.cursor._rowcount = nrows if nrows is not None else -1
+        return b""
 
     async def write(self, buffer: Union[str, bytes]) -> None:
         conn = self.connection
@@ -213,10 +226,13 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
     async def _finish(self, error: str = "") -> None:
         conn = self.connection
         berr = error.encode(conn.client_encoding, "replace") if error else None
-        await conn.wait(copy_end(conn.pgconn, berr))
+        res = await conn.wait(copy_end(conn.pgconn, berr))
+        nrows = res.command_tuples
+        self.cursor._rowcount = nrows if nrows is not None else -1
         self._finished = True
 
     async def __aenter__(self) -> "AsyncCopy":
+        assert not self._finished
         return self
 
     async def __aexit__(
@@ -242,6 +258,6 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
     async def __aiter__(self) -> AsyncIterator[bytes]:
         while True:
             data = await self.read()
-            if data is None:
+            if not data:
                 break
             yield data
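
Note: since read() now returns b"" at the end of the data, a Copy object
can be drained exactly like a binary file. A minimal sketch of the
user-facing pattern (the empty conninfo and the copy_demo table are
hypothetical, not part of this patch):

    import psycopg3

    conn = psycopg3.connect("")  # assumes a reachable default database
    cur = conn.cursor()
    cur.execute("create table if not exists copy_demo (id int, data text)")
    with cur.copy("copy copy_demo to stdout") as copy:
        while True:
            block = copy.read()
            if not block:  # b"" signals the end of data, like file.read()
                break
            print(len(block))  # consume the bytes block
    # or, equivalently, rely on __iter__: for block in copy: ...
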
""" if self._finished: - return None + return b"" conn = self.connection - rv = conn.wait(copy_from(conn.pgconn)) - if rv is None: - self._finished = True + res = conn.wait(copy_from(conn.pgconn)) + if isinstance(res, bytes): + return res - return rv + # res is the final PGresult + self._finished = True + nrows = res.command_tuples + self.cursor._rowcount = nrows if nrows is not None else -1 + return b"" def write(self, buffer: Union[str, bytes]) -> None: """Write a block of data after a :sql:`COPY FROM` operation.""" @@ -154,10 +160,13 @@ class Copy(BaseCopy["Connection"]): """Terminate a :sql:`COPY FROM` operation.""" conn = self.connection berr = error.encode(conn.client_encoding, "replace") if error else None - conn.wait(copy_end(conn.pgconn, berr)) + res = conn.wait(copy_end(conn.pgconn, berr)) + nrows = res.command_tuples + self.cursor._rowcount = nrows if nrows is not None else -1 self._finished = True def __enter__(self) -> "Copy": + assert not self._finished return self def __exit__( @@ -183,7 +192,7 @@ class Copy(BaseCopy["Connection"]): def __iter__(self) -> Iterator[bytes]: while True: data = self.read() - if data is None: + if not data: break yield data @@ -191,16 +200,20 @@ class Copy(BaseCopy["Connection"]): class AsyncCopy(BaseCopy["AsyncConnection"]): """Manage an asynchronous :sql:`COPY` operation.""" - async def read(self) -> Optional[bytes]: + async def read(self) -> bytes: if self._finished: - return None + return b"" conn = self.connection - rv = await conn.wait(copy_from(conn.pgconn)) - if rv is None: - self._finished = True + res = await conn.wait(copy_from(conn.pgconn)) + if isinstance(res, bytes): + return res - return rv + # res is the final PGresult + self._finished = True + nrows = res.command_tuples + self.cursor._rowcount = nrows if nrows is not None else -1 + return b"" async def write(self, buffer: Union[str, bytes]) -> None: conn = self.connection @@ -213,10 +226,13 @@ class AsyncCopy(BaseCopy["AsyncConnection"]): async def _finish(self, error: str = "") -> None: conn = self.connection berr = error.encode(conn.client_encoding, "replace") if error else None - await conn.wait(copy_end(conn.pgconn, berr)) + res = await conn.wait(copy_end(conn.pgconn, berr)) + nrows = res.command_tuples + self.cursor._rowcount = nrows if nrows is not None else -1 self._finished = True async def __aenter__(self) -> "AsyncCopy": + assert not self._finished return self async def __aexit__( @@ -242,6 +258,6 @@ class AsyncCopy(BaseCopy["AsyncConnection"]): async def __aiter__(self) -> AsyncIterator[bytes]: while True: data = await self.read() - if data is None: + if not data: break yield data diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index f3fdddbe0..86996d900 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -162,6 +162,7 @@ class BaseCursor(Generic[ConnectionType]): ExecStatus = pq.ExecStatus _transformer: "Transformer" + _rowcount: int def __init__( self, @@ -579,7 +580,7 @@ class Cursor(BaseCursor["Connection"]): self._check_copy_results(results) self.pgresult = results[0] # will set it on the transformer too - return Copy(connection=self.connection, transformer=self._transformer) + return Copy(self) class AsyncCursor(BaseCursor["AsyncConnection"]): @@ -715,10 +716,7 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): self._check_copy_results(results) self.pgresult = results[0] # will set it on the transformer too - return AsyncCopy( - connection=self.connection, - transformer=self._transformer, - ) + return 
AsyncCopy(self) class NamedCursorMixin: diff --git a/psycopg3/psycopg3/generators.py b/psycopg3/psycopg3/generators.py index 15d2ad2bb..96944206d 100644 --- a/psycopg3/psycopg3/generators.py +++ b/psycopg3/psycopg3/generators.py @@ -16,18 +16,19 @@ when the file descriptor is ready. # Copyright (C) 2020 The Psycopg Team import logging -from typing import List, Optional +from typing import List, Optional, Union from . import pq from . import errors as e from .proto import PQGen from .waiting import Wait, Ready from .encodings import py_codecs +from .pq.proto import PGconn, PGresult logger = logging.getLogger(__name__) -def connect(conninfo: str) -> PQGen[pq.proto.PGconn]: +def connect(conninfo: str) -> PQGen[PGconn]: """ Generator to create a database connection without blocking. @@ -59,7 +60,7 @@ def connect(conninfo: str) -> PQGen[pq.proto.PGconn]: return conn -def execute(pgconn: pq.proto.PGconn) -> PQGen[List[pq.proto.PGresult]]: +def execute(pgconn: PGconn) -> PQGen[List[PGresult]]: """ Generator sending a query and returning results without blocking. @@ -75,7 +76,7 @@ def execute(pgconn: pq.proto.PGconn) -> PQGen[List[pq.proto.PGresult]]: return rv -def send(pgconn: pq.proto.PGconn) -> PQGen[None]: +def send(pgconn: PGconn) -> PQGen[None]: """ Generator to send a query to the server without blocking. @@ -99,7 +100,7 @@ def send(pgconn: pq.proto.PGconn) -> PQGen[None]: continue -def fetch(pgconn: pq.proto.PGconn) -> PQGen[List[pq.proto.PGresult]]: +def fetch(pgconn: PGconn) -> PQGen[List[PGresult]]: """ Generator retrieving results from the database without blocking. @@ -110,7 +111,7 @@ def fetch(pgconn: pq.proto.PGconn) -> PQGen[List[pq.proto.PGresult]]: or error). """ S = pq.ExecStatus - results: List[pq.proto.PGresult] = [] + results: List[PGresult] = [] while 1: pgconn.consume_input() if pgconn.is_busy(): @@ -137,7 +138,7 @@ def fetch(pgconn: pq.proto.PGconn) -> PQGen[List[pq.proto.PGresult]]: return results -def notifies(pgconn: pq.proto.PGconn) -> PQGen[List[pq.PGnotify]]: +def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]: yield pgconn.socket, Wait.R pgconn.consume_input() @@ -152,7 +153,7 @@ def notifies(pgconn: pq.proto.PGconn) -> PQGen[List[pq.PGnotify]]: return ns -def copy_from(pgconn: pq.proto.PGconn) -> PQGen[Optional[bytes]]: +def copy_from(pgconn: PGconn) -> PQGen[Union[bytes, PGresult]]: while 1: nbytes, data = pgconn.get_copy_data(1) if nbytes != 0: @@ -167,27 +168,23 @@ def copy_from(pgconn: pq.proto.PGconn) -> PQGen[Optional[bytes]]: return data # Retrieve the final result of copy - results = yield from fetch(pgconn) - if len(results) != 1: - raise e.InternalError( - f"1 result expected from copy end, got {len(results)}" - ) - if results[0].status != pq.ExecStatus.COMMAND_OK: + (result,) = yield from fetch(pgconn) + if result.status != pq.ExecStatus.COMMAND_OK: encoding = py_codecs.get( pgconn.parameter_status(b"client_encoding") or "", "utf-8" ) - raise e.error_from_result(results[0], encoding=encoding) + raise e.error_from_result(result, encoding=encoding) - return None + return result -def copy_to(pgconn: pq.proto.PGconn, buffer: bytes) -> PQGen[None]: +def copy_to(pgconn: PGconn, buffer: bytes) -> PQGen[None]: # Retry enqueuing data until successful while pgconn.put_copy_data(buffer) == 0: yield pgconn.socket, Wait.W -def copy_end(pgconn: pq.proto.PGconn, error: Optional[bytes]) -> PQGen[None]: +def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]: # Retry enqueuing end copy message until successful while pgconn.put_copy_end(error) == 
0: yield pgconn.socket, Wait.W @@ -200,13 +197,11 @@ def copy_end(pgconn: pq.proto.PGconn, error: Optional[bytes]) -> PQGen[None]: break # Retrieve the final result of copy - results = yield from fetch(pgconn) - if len(results) != 1: - raise e.InternalError( - f"1 result expected from copy end, got {len(results)}" - ) - if results[0].status != pq.ExecStatus.COMMAND_OK: + (result,) = yield from fetch(pgconn) + if result.status != pq.ExecStatus.COMMAND_OK: encoding = py_codecs.get( pgconn.parameter_status(b"client_encoding") or "", "utf-8" ) - raise e.error_from_result(results[0], encoding=encoding) + raise e.error_from_result(result, encoding=encoding) + + return result diff --git a/psycopg3/psycopg3/pq/pq_ctypes.py b/psycopg3/psycopg3/pq/pq_ctypes.py index 70cdedd75..5532540b0 100644 --- a/psycopg3/psycopg3/pq/pq_ctypes.py +++ b/psycopg3/psycopg3/pq/pq_ctypes.py @@ -518,7 +518,7 @@ class PGconn: raise PQerror(f"sending copy end failed: {error_message(self)}") return rv - def get_copy_data(self, async_: int) -> Tuple[int, Optional[bytes]]: + def get_copy_data(self, async_: int) -> Tuple[int, bytes]: buffer_ptr = c_char_p() nbytes = impl.PQgetCopyData(self.pgconn_ptr, byref(buffer_ptr), async_) if nbytes == -2: @@ -529,7 +529,7 @@ class PGconn: impl.PQfreemem(buffer_ptr) return nbytes, data else: - return nbytes, None + return nbytes, b"" def make_empty_result(self, exec_status: ExecStatus) -> "PGresult": rv = impl.PQmakeEmptyPGresult(self.pgconn_ptr, exec_status) diff --git a/psycopg3/psycopg3/pq/proto.py b/psycopg3/psycopg3/pq/proto.py index a8b4cd655..18a50c96f 100644 --- a/psycopg3/psycopg3/pq/proto.py +++ b/psycopg3/psycopg3/pq/proto.py @@ -231,7 +231,7 @@ class PGconn(Protocol): def put_copy_end(self, error: Optional[bytes] = None) -> int: ... - def get_copy_data(self, async_: int) -> Tuple[int, Optional[bytes]]: + def get_copy_data(self, async_: int) -> Tuple[int, bytes]: ... 
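
Note: at the libpq wrapper level the end-of-copy sentinel is now
(-1, b"") instead of (-1, None), so the second item of the tuple is
always bytes. A sketch of draining COPY OUT synchronously through this
interface (assumes pgconn has already sent a COPY ... TO STDOUT query;
consume() is a hypothetical callback):

    from psycopg3 import pq

    while True:
        nbytes, data = pgconn.get_copy_data(0)  # 0 = blocking mode
        if nbytes == -1:  # end of copy: data is b"", never None
            break
        consume(data)
    res = pgconn.get_result()
    assert res.status == pq.ExecStatus.COMMAND_OK
    nrows = res.command_tuples  # the value cursor.rowcount is set from
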
diff --git a/psycopg3_c/psycopg3_c/pq_cython.pyx b/psycopg3_c/psycopg3_c/pq_cython.pyx
index 567c3d8c8..c310f5a04 100644
--- a/psycopg3_c/psycopg3_c/pq_cython.pyx
+++ b/psycopg3_c/psycopg3_c/pq_cython.pyx
@@ -447,7 +447,7 @@ cdef class PGconn:
             raise PQerror(f"sending copy end failed: {error_message(self)}")
         return rv
 
-    def get_copy_data(self, async_: int) -> Tuple[int, Optional[bytes]]:
+    def get_copy_data(self, async_: int) -> Tuple[int, bytes]:
         cdef char *buffer_ptr = NULL
         cdef int nbytes
         nbytes = impl.PQgetCopyData(self.pgconn_ptr, &buffer_ptr, async_)
@@ -459,7 +459,7 @@
             impl.PQfreemem(buffer_ptr)
             return nbytes, data
         else:
-            return nbytes, None
+            return nbytes, b""
 
     def make_empty_result(self, exec_status: ExecStatus) -> PGresult:
         cdef impl.PGresult *rv = impl.PQmakeEmptyPGresult(
diff --git a/tests/pq/test_copy.py b/tests/pq/test_copy.py
index db0c641d6..41a35f176 100644
--- a/tests/pq/test_copy.py
+++ b/tests/pq/test_copy.py
@@ -160,7 +160,7 @@ def test_copy_out_read(pgconn, format):
         assert nbytes == len(data)
         assert data == row
 
-    assert pgconn.get_copy_data(0) == (-1, None)
+    assert pgconn.get_copy_data(0) == (-1, b"")
 
     res = pgconn.get_result()
     assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
diff --git a/tests/test_copy.py b/tests/test_copy.py
index 6596d71ac..a38d156fb 100644
--- a/tests/test_copy.py
+++ b/tests/test_copy.py
@@ -1,3 +1,8 @@
+import string
+import hashlib
+from io import BytesIO, StringIO
+from itertools import cycle
+
 import pytest
 
 from psycopg3 import pq
@@ -50,10 +55,10 @@ def test_copy_out_read(conn, format):
         got = copy.read()
         assert got == row
 
-    assert copy.read() is None
-    assert copy.read() is None
+    assert copy.read() == b""
+    assert copy.read() == b""
 
-    assert copy.read() is None
+    assert copy.read() == b""
 
 
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
@@ -201,6 +206,144 @@ from copy_in group by 1, 2, 3
     assert data == [(True, True, 1, 256)]
 
 
+@pytest.mark.slow
+def test_copy_from_to(conn):
+    # Roundtrip from file to database to file blockwise
+    gen = DataGenerator(conn, nrecs=1024, srec=10 * 1024)
+    gen.ensure_table()
+    cur = conn.cursor()
+    with cur.copy("copy copy_in from stdin") as copy:
+        for block in gen.blocks():
+            copy.write(block)
+
+    gen.assert_data()
+
+    f = StringIO()
+    with cur.copy("copy copy_in to stdout") as copy:
+        for block in copy:
+            f.write(block.decode("utf8"))
+
+    f.seek(0)
+    assert gen.sha(f) == gen.sha(gen.file())
+
+
+@pytest.mark.slow
+def test_copy_from_to_bytes(conn):
+    # Roundtrip from file to database to file blockwise
+    gen = DataGenerator(conn, nrecs=1024, srec=10 * 1024)
+    gen.ensure_table()
+    cur = conn.cursor()
+    with cur.copy("copy copy_in from stdin") as copy:
+        for block in gen.blocks():
+            copy.write(block.encode("utf8"))
+
+    gen.assert_data()
+
+    f = BytesIO()
+    with cur.copy("copy copy_in to stdout") as copy:
+        for block in copy:
+            f.write(block)
+
+    f.seek(0)
+    assert gen.sha(f) == gen.sha(gen.file())
+
+
+@pytest.mark.slow
+def test_copy_from_insane_size(conn):
+    # Trying to trigger a "would block" error
+    gen = DataGenerator(
+        conn, nrecs=4 * 1024, srec=10 * 1024, block_size=20 * 1024 * 1024
+    )
+    gen.ensure_table()
+    cur = conn.cursor()
+    with cur.copy("copy copy_in from stdin") as copy:
+        for block in gen.blocks():
+            copy.write(block)
+
+    gen.assert_data()
+
+
+def test_copy_rowcount(conn):
+    gen = DataGenerator(conn, nrecs=3, srec=10)
+    gen.ensure_table()
+
+    cur = conn.cursor()
+    with cur.copy("copy copy_in from stdin") as copy:
+        for block in gen.blocks():
+            copy.write(block)
+    assert cur.rowcount == 3
+
+    gen = DataGenerator(conn, nrecs=2, srec=10, offset=3)
+    with cur.copy("copy copy_in from stdin") as copy:
+        for rec in gen.records():
+            copy.write_row(rec)
+    assert cur.rowcount == 2
+
+    with cur.copy("copy copy_in to stdout") as copy:
+        for block in copy:
+            pass
+    assert cur.rowcount == 5
+
+    with pytest.raises(e.BadCopyFileFormat):
+        with cur.copy("copy copy_in (id) from stdin") as copy:
+            for rec in gen.records():
+                copy.write_row(rec)
+    assert cur.rowcount == -1
+
+
 def ensure_table(cur, tabledef, name="copy_in"):
     cur.execute(f"drop table if exists {name}")
     cur.execute(f"create table {name} ({tabledef})")
+
+
+class DataGenerator:
+    def __init__(self, conn, nrecs, srec, offset=0, block_size=8192):
+        self.conn = conn
+        self.nrecs = nrecs
+        self.srec = srec
+        self.offset = offset
+        self.block_size = block_size
+
+    def ensure_table(self):
+        cur = self.conn.cursor()
+        ensure_table(cur, "id integer primary key, data text")
+
+    def records(self):
+        for i, c in zip(range(self.nrecs), cycle(string.ascii_letters)):
+            s = c * self.srec
+            yield (i + self.offset, s)
+
+    def file(self):
+        f = StringIO()
+        for i, s in self.records():
+            f.write("%s\t%s\n" % (i, s))
+
+        f.seek(0)
+        return f
+
+    def blocks(self):
+        f = self.file()
+        while True:
+            block = f.read(self.block_size)
+            if not block:
+                break
+            yield block
+
+    def assert_data(self):
+        cur = self.conn.cursor()
+        cur.execute("select id, data from copy_in order by id")
+        for record in self.records():
+            assert record == cur.fetchone()
+
+        assert cur.fetchone() is None
+
+    def sha(self, f):
+        m = hashlib.sha256()
+        while 1:
+            block = f.read()
+            if not block:
+                break
+            if isinstance(block, str):
+                block = block.encode("utf8")
+            m.update(block)
+        return m.hexdigest()
diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py
index 2a471999b..da491a4bb 100644
--- a/tests/test_copy_async.py
+++ b/tests/test_copy_async.py
@@ -1,3 +1,8 @@
+import string
+import hashlib
+from io import BytesIO, StringIO
+from itertools import cycle
+
 import pytest
 
 from psycopg3 import pq
@@ -25,10 +30,10 @@ async def test_copy_out_read(aconn, format):
         got = await copy.read()
         assert got == row
 
-    assert await copy.read() is None
-    assert await copy.read() is None
+    assert await copy.read() == b""
+    assert await copy.read() == b""
 
-    assert await copy.read() is None
+    assert await copy.read() == b""
 
 
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
@@ -189,6 +194,144 @@ from copy_in group by 1, 2, 3
     assert data == [(True, True, 1, 256)]
 
 
+@pytest.mark.slow
+async def test_copy_from_to(aconn):
+    # Roundtrip from file to database to file blockwise
+    gen = DataGenerator(aconn, nrecs=1024, srec=10 * 1024)
+    await gen.ensure_table()
+    cur = await aconn.cursor()
+    async with cur.copy("copy copy_in from stdin") as copy:
+        for block in gen.blocks():
+            await copy.write(block)
+
+    await gen.assert_data()
+
+    f = StringIO()
+    async with cur.copy("copy copy_in to stdout") as copy:
+        async for block in copy:
+            f.write(block.decode("utf8"))
+
+    f.seek(0)
+    assert gen.sha(f) == gen.sha(gen.file())
+
+
+@pytest.mark.slow
+async def test_copy_from_to_bytes(aconn):
+    # Roundtrip from file to database to file blockwise
+    gen = DataGenerator(aconn, nrecs=1024, srec=10 * 1024)
+    await gen.ensure_table()
+    cur = await aconn.cursor()
+    async with cur.copy("copy copy_in from stdin") as copy:
+        for block in gen.blocks():
+            await copy.write(block.encode("utf8"))
+
+    await gen.assert_data()
+
+    f = BytesIO()
+    async with cur.copy("copy copy_in to stdout") as copy:
+        async for block in copy:
+            f.write(block)
+
+    f.seek(0)
+    assert gen.sha(f) == gen.sha(gen.file())
+
+
+@pytest.mark.slow
+async def test_copy_from_insane_size(aconn):
+    # Trying to trigger a "would block" error
+    gen = DataGenerator(
+        aconn, nrecs=4 * 1024, srec=10 * 1024, block_size=20 * 1024 * 1024
+    )
+    await gen.ensure_table()
+    cur = await aconn.cursor()
+    async with cur.copy("copy copy_in from stdin") as copy:
+        for block in gen.blocks():
+            await copy.write(block)
+
+    await gen.assert_data()
+
+
+async def test_copy_rowcount(aconn):
+    gen = DataGenerator(aconn, nrecs=3, srec=10)
+    await gen.ensure_table()
+
+    cur = await aconn.cursor()
+    async with cur.copy("copy copy_in from stdin") as copy:
+        for block in gen.blocks():
+            await copy.write(block)
+    assert cur.rowcount == 3
+
+    gen = DataGenerator(aconn, nrecs=2, srec=10, offset=3)
+    async with cur.copy("copy copy_in from stdin") as copy:
+        for rec in gen.records():
+            await copy.write_row(rec)
+    assert cur.rowcount == 2
+
+    async with cur.copy("copy copy_in to stdout") as copy:
+        async for block in copy:
+            pass
+    assert cur.rowcount == 5
+
+    with pytest.raises(e.BadCopyFileFormat):
+        async with cur.copy("copy copy_in (id) from stdin") as copy:
+            for rec in gen.records():
+                await copy.write_row(rec)
+    assert cur.rowcount == -1
+
+
 async def ensure_table(cur, tabledef, name="copy_in"):
     await cur.execute(f"drop table if exists {name}")
     await cur.execute(f"create table {name} ({tabledef})")
+
+
+class DataGenerator:
+    def __init__(self, conn, nrecs, srec, offset=0, block_size=8192):
+        self.conn = conn
+        self.nrecs = nrecs
+        self.srec = srec
+        self.offset = offset
+        self.block_size = block_size
+
+    async def ensure_table(self):
+        cur = await self.conn.cursor()
+        await ensure_table(cur, "id integer primary key, data text")
+
+    def records(self):
+        for i, c in zip(range(self.nrecs), cycle(string.ascii_letters)):
+            s = c * self.srec
+            yield (i + self.offset, s)
+
+    def file(self):
+        f = StringIO()
+        for i, s in self.records():
+            f.write("%s\t%s\n" % (i, s))
+
+        f.seek(0)
+        return f
+
+    def blocks(self):
+        f = self.file()
+        while True:
+            block = f.read(self.block_size)
+            if not block:
+                break
+            yield block
+
+    async def assert_data(self):
+        cur = await self.conn.cursor()
+        await cur.execute("select id, data from copy_in order by id")
+        for record in self.records():
+            assert record == await cur.fetchone()
+
+        assert await cur.fetchone() is None
+
+    def sha(self, f):
+        m = hashlib.sha256()
+        while 1:
+            block = f.read()
+            if not block:
+                break
+            if isinstance(block, str):
+                block = block.encode("utf8")
+            m.update(block)
+        return m.hexdigest()
-- 
2.47.2
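
Note: the user-visible effect, mirroring test_copy_rowcount above:
cursor.rowcount now reports the rows moved by COPY in both directions,
and -1 after a failed COPY. A sketch, given an open connection conn
(the two-row payload is arbitrary; copy_in is the table the tests
create, with columns id and data):

    cur = conn.cursor()
    with cur.copy("copy copy_in from stdin") as copy:
        copy.write("1\taaa\n2\tbbb\n")
    assert cur.rowcount == 2

    with cur.copy("copy copy_in to stdout") as copy:
        for block in copy:
            pass
    assert cur.rowcount == 2
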