From: Daniele Varrazzo Date: Sat, 11 Jun 2022 22:47:50 +0000 (+0200) Subject: refactor(copy): make the writer entirely responsible for the libpq copy state X-Git-Tag: 3.1~44^2~9 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b77bd8c9a37fa9b83a3ce099584e798b36ec46ee;p=thirdparty%2Fpsycopg.git refactor(copy): make the writer entirely responsible for the libpq copy state This way we can create a writer which is entirely independent from the libpq and the connection, which is useful, for instance, to format a file with copy data. --- diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py index b9d641921..6ac502500 100644 --- a/psycopg/psycopg/copy.py +++ b/psycopg/psycopg/copy.py @@ -15,28 +15,28 @@ from typing import Any, AsyncIterator, Dict, Generic, Iterator, List, Match from typing import Optional, Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING from . import pq +from . import adapt from . import errors as e from .abc import Buffer, ConnectionType, PQGen, Transformer -from .adapt import PyFormat from ._compat import create_task from ._cmodule import _psycopg from ._encodings import pgconn_encoding from .generators import copy_from, copy_to, copy_end if TYPE_CHECKING: - from .pq.abc import PGresult from .cursor import BaseCursor, Cursor from .cursor_async import AsyncCursor from .connection import Connection # noqa: F401 from .connection_async import AsyncConnection # noqa: F401 -PY_TEXT = PyFormat.TEXT -PY_BINARY = PyFormat.BINARY +PY_TEXT = adapt.PyFormat.TEXT +PY_BINARY = adapt.PyFormat.BINARY TEXT = pq.Format.TEXT BINARY = pq.Format.BINARY COPY_IN = pq.ExecStatus.COPY_IN +COPY_OUT = pq.ExecStatus.COPY_OUT ACTIVE = pq.TransactionStatus.ACTIVE @@ -83,14 +83,22 @@ class BaseCopy(Generic[ConnectionType]): self.connection = cursor.connection self._pgconn = self.connection.pgconn - tx = cursor._tx - assert tx.pgresult, "The Transformer doesn't have a PGresult set" - self._pgresult: "PGresult" = tx.pgresult - - if 
self._pgresult.binary_tuples == TEXT: - self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn)) + result = cursor.pgresult + if result: + self._direction = result.status + if self._direction != COPY_IN and self._direction != COPY_OUT: + raise e.ProgrammingError( + "the cursor should have performed a COPY operation;" + f" its status is {pq.ExecStatus(self._direction).name} instead" + ) else: + self._direction = COPY_IN + + tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor) + if result and result.binary_tuples: self.formatter = BinaryFormatter(tx) + else: + self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn)) self._finished = False @@ -125,7 +133,7 @@ class BaseCopy(Generic[ConnectionType]): registry = self.cursor.adapters.types oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types] - if self._pgresult.status == COPY_IN: + if self._direction == COPY_IN: self.formatter.transformer.set_dumper_types(oids, self.formatter.format) else: self.formatter.transformer.set_loader_types(oids, self.formatter.format) @@ -160,20 +168,6 @@ class BaseCopy(Generic[ConnectionType]): return row - def _end_copy_in_gen(self, exc: Optional[BaseException]) -> PQGen[None]: - bmsg: Optional[bytes] - if exc: - msg = f"error from Python: {type(exc).__qualname__} - {exc}" - bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace") - else: - bmsg = None - - res = yield from copy_end(self._pgconn, bmsg) - - nrows = res.command_tuples - self.cursor._rowcount = nrows if nrows is not None else -1 - self._finished = True - def _end_copy_out_gen(self, exc: Optional[BaseException]) -> PQGen[None]: if not exc: return @@ -205,7 +199,7 @@ class Copy(BaseCopy["Connection[Any]"]): def __init__(self, cursor: "Cursor[Any]", *, writer: Optional["Writer"] = None): super().__init__(cursor) if not writer: - writer = QueueWriter(cursor.connection) + writer = QueueWriter(cursor) self.writer = writer self._write = writer.write @@ 
-288,12 +282,12 @@ class Copy(BaseCopy["Connection[Any]"]): by exit. It is available if, despite what is documented, you end up using the `Copy` object outside a block. """ - if self._pgresult.status == COPY_IN: + if self._direction == COPY_IN: data = self.formatter.end() if data: self._write(data) - self.writer.finish() - self.connection.wait(self._end_copy_in_gen(exc)) + self.writer.finish(exc) + self._finished = True else: self.connection.wait(self._end_copy_out_gen(exc)) @@ -310,7 +304,7 @@ class Writer(ABC): """ ... - def finish(self) -> None: + def finish(self, exc: Optional[BaseException] = None) -> None: """ Called when write operations are finished. """ @@ -318,8 +312,9 @@ class Writer(ABC): class ConnectionWriter(Writer): - def __init__(self, connection: "Connection[Any]"): - self.connection = connection + def __init__(self, cursor: "Cursor[Any]"): + self.cursor = cursor + self.connection = cursor.connection self._pgconn = self.connection.pgconn def write(self, data: Buffer) -> None: @@ -335,6 +330,18 @@ class ConnectionWriter(Writer): copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE]) ) + def finish(self, exc: Optional[BaseException] = None) -> None: + bmsg: Optional[bytes] + if exc: + msg = f"error from Python: {type(exc).__qualname__} - {exc}" + bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace") + else: + bmsg = None + + res = self.connection.wait(copy_end(self._pgconn, bmsg)) + nrows = res.command_tuples + self.cursor._rowcount = nrows if nrows is not None else -1 + class QueueWriter(ConnectionWriter): """ @@ -345,8 +352,8 @@ class QueueWriter(ConnectionWriter): on the connection. 
""" - def __init__(self, connection: "Connection[Any]"): - super().__init__(connection) + def __init__(self, cursor: "Cursor[Any]"): + super().__init__(cursor) self._queue: queue.Queue[bytes] = queue.Queue(maxsize=QUEUE_SIZE) self._worker: Optional[threading.Thread] = None @@ -391,7 +398,7 @@ class QueueWriter(ConnectionWriter): for i in range(0, len(data), MAX_BUFFER_SIZE): self._queue.put(data[i : i + MAX_BUFFER_SIZE]) - def finish(self) -> None: + def finish(self, exc: Optional[BaseException] = None) -> None: self._queue.put(b"") if self._worker: @@ -402,6 +409,8 @@ class QueueWriter(ConnectionWriter): if self._worker_error: raise self._worker_error + super().finish(exc) + class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): """Manage an asynchronous :sql:`COPY` operation.""" @@ -416,7 +425,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): super().__init__(cursor) if not writer: - writer = AsyncQueueWriter(cursor.connection) + writer = AsyncQueueWriter(cursor) self.writer = writer self._write = writer.write @@ -464,12 +473,12 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): await self._write(data) async def finish(self, exc: Optional[BaseException]) -> None: - if self._pgresult.status == COPY_IN: + if self._direction == COPY_IN: data = self.formatter.end() if data: await self._write(data) - await self.writer.finish() - await self.connection.wait(self._end_copy_in_gen(exc)) + await self.writer.finish(exc) + self._finished = True else: await self.connection.wait(self._end_copy_out_gen(exc)) @@ -486,7 +495,7 @@ class AsyncWriter(ABC): """ ... - async def finish(self) -> None: + async def finish(self, exc: Optional[BaseException] = None) -> None: """ Called when write operations are finished. 
""" @@ -494,8 +503,9 @@ class AsyncWriter(ABC): class AsyncConnectionWriter(AsyncWriter): - def __init__(self, connection: "AsyncConnection[Any]"): - self.connection = connection + def __init__(self, cursor: "AsyncCursor[Any]"): + self.cursor = cursor + self.connection = cursor.connection self._pgconn = self.connection.pgconn async def write(self, data: Buffer) -> None: @@ -511,6 +521,18 @@ class AsyncConnectionWriter(AsyncWriter): copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE]) ) + async def finish(self, exc: Optional[BaseException] = None) -> None: + bmsg: Optional[bytes] + if exc: + msg = f"error from Python: {type(exc).__qualname__} - {exc}" + bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace") + else: + bmsg = None + + res = await self.connection.wait(copy_end(self._pgconn, bmsg)) + nrows = res.command_tuples + self.cursor._rowcount = nrows if nrows is not None else -1 + class AsyncQueueWriter(AsyncConnectionWriter): """ @@ -521,8 +543,8 @@ class AsyncQueueWriter(AsyncConnectionWriter): on the connection. 
""" - def __init__(self, connection: "AsyncConnection[Any]"): - super().__init__(connection) + def __init__(self, cursor: "AsyncCursor[Any]"): + super().__init__(cursor) self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=QUEUE_SIZE) self._worker: Optional[asyncio.Future[None]] = None @@ -554,13 +576,15 @@ class AsyncQueueWriter(AsyncConnectionWriter): for i in range(0, len(data), MAX_BUFFER_SIZE): await self._queue.put(data[i : i + MAX_BUFFER_SIZE]) - async def finish(self) -> None: + async def finish(self, exc: Optional[BaseException] = None) -> None: await self._queue.put(b"") if self._worker: await asyncio.gather(self._worker) self._worker = None # break reference loops if any + await super().finish(exc) + class Formatter(ABC): """ diff --git a/tests/test_copy.py b/tests/test_copy.py index 25bc1a5dd..78a7ebec7 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -46,6 +46,8 @@ sample_binary_rows = [ sample_binary = b"".join(sample_binary_rows) +special_chars = {8: "b", 9: "t", 10: "n", 11: "v", 12: "f", 13: "r", ord("\\"): "\\"} + @pytest.mark.parametrize("format", Format) def test_copy_out_read(conn, format): @@ -458,6 +460,29 @@ from copy_in group by 1, 2, 3 assert data == [(True, True, 1, 256)] +def test_copy_in_format(conn): + writer = BytesWriter() + conn.execute("set client_encoding to utf8") + cur = conn.cursor() + with psycopg.copy.Copy(cur, writer=writer) as copy: + for i in range(1, 256): + copy.write_row((i, chr(i))) + + writer.file.seek(0) + rows = writer.file.read().split(b"\n") + assert not rows[-1] + del rows[-1] + + for i, row in enumerate(rows, start=1): + fields = row.split(b"\t") + assert len(fields) == 2 + assert int(fields[0].decode()) == i + if i in special_chars: + assert fields[1].decode() == f"\\{special_chars[i]}" + else: + assert fields[1].decode() == chr(i) + + @pytest.mark.slow def test_copy_from_to(conn): # Roundtrip from file to database to file blockwise @@ -620,7 +645,7 @@ def test_worker_error_propagated(conn, 
monkeypatch): ) def test_connection_writer(conn, format, buffer): cur = conn.cursor() - writer = psycopg.copy.ConnectionWriter(conn) + writer = psycopg.copy.ConnectionWriter(cur) ensure_table(cur, sample_tabledef) with cur.copy( @@ -832,3 +857,11 @@ class DataGenerator: block = block.encode() m.update(block) return m.hexdigest() + + +class BytesWriter(psycopg.copy.Writer): + def __init__(self): + self.file = BytesIO() + + def write(self, data): + self.file.write(data) diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index 045ad8514..bbaf77aaa 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -20,7 +20,7 @@ from psycopg.types.numeric import Int4 from .utils import alist, eur, gc_collect from .test_copy import sample_text, sample_binary, sample_binary_rows # noqa from .test_copy import sample_values, sample_records, sample_tabledef -from .test_copy import py_to_raw +from .test_copy import py_to_raw, special_chars pytestmark = [ pytest.mark.asyncio, @@ -462,6 +462,29 @@ from copy_in group by 1, 2, 3 assert data == [(True, True, 1, 256)] +async def test_copy_in_format(aconn): + writer = AsyncBytesWriter() + await aconn.execute("set client_encoding to utf8") + cur = aconn.cursor() + async with psycopg.copy.AsyncCopy(cur, writer=writer) as copy: + for i in range(1, 256): + await copy.write_row((i, chr(i))) + + writer.file.seek(0) + rows = writer.file.read().split(b"\n") + assert not rows[-1] + del rows[-1] + + for i, row in enumerate(rows, start=1): + fields = row.split(b"\t") + assert len(fields) == 2 + assert int(fields[0].decode()) == i + if i in special_chars: + assert fields[1].decode() == f"\\{special_chars[i]}" + else: + assert fields[1].decode() == chr(i) + + @pytest.mark.slow async def test_copy_from_to(aconn): # Roundtrip from file to database to file blockwise @@ -625,7 +648,7 @@ async def test_worker_error_propagated(aconn, monkeypatch): ) async def test_connection_writer(aconn, format, buffer): cur = aconn.cursor() - 
writer = psycopg.copy.AsyncConnectionWriter(aconn) + writer = psycopg.copy.AsyncConnectionWriter(cur) await ensure_table(cur, sample_tabledef) async with cur.copy( @@ -827,3 +850,11 @@ class DataGenerator: block = block.encode() m.update(block) return m.hexdigest() + + +class AsyncBytesWriter(psycopg.copy.AsyncWriter): + def __init__(self): + self.file = BytesIO() + + async def write(self, data): + self.file.write(data)