From: Daniele Varrazzo Date: Sat, 11 Jun 2022 15:19:46 +0000 (+0200) Subject: feat(copy): add writer param AsyncCursor.copy() X-Git-Tag: 3.1~44^2~10 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=502dff2330287523b904994f646ab3c022fc8ae7;p=thirdparty%2Fpsycopg.git feat(copy): add writer param AsyncCursor.copy() Add async copy writers similar to the sync ones. --- diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py index 42ec488c9..b9d641921 100644 --- a/psycopg/psycopg/copy.py +++ b/psycopg/psycopg/copy.py @@ -77,7 +77,6 @@ class BaseCopy(Generic[ConnectionType]): _Self = TypeVar("_Self", bound="BaseCopy[ConnectionType]") formatter: "Formatter" - writer: "Writer[ConnectionType]" def __init__(self, cursor: "BaseCursor[ConnectionType, Any]"): self.cursor = cursor @@ -201,12 +200,9 @@ class Copy(BaseCopy["Connection[Any]"]): __module__ = "psycopg" - def __init__( - self, - cursor: "Cursor[Any]", - *, - writer: Optional["Writer[Connection[Any]]"] = None, - ): + writer: "Writer" + + def __init__(self, cursor: "Cursor[Any]", *, writer: Optional["Writer"] = None): super().__init__(cursor) if not writer: writer = QueueWriter(cursor.connection) @@ -302,15 +298,128 @@ class Copy(BaseCopy["Connection[Any]"]): self.connection.wait(self._end_copy_out_gen(exc)) +class Writer(ABC): + """ + A class to write copy data somewhere. + """ + + @abstractmethod + def write(self, data: Buffer) -> None: + """ + Write some data to destination. + """ + ... + + def finish(self) -> None: + """ + Called when write operations are finished. + """ + pass + + +class ConnectionWriter(Writer): + def __init__(self, connection: "Connection[Any]"): + self.connection = connection + self._pgconn = self.connection.pgconn + + def write(self, data: Buffer) -> None: + if len(data) <= MAX_BUFFER_SIZE: + # Most used path: we don't need to split the buffer in smaller + # bits, so don't make a copy. 
+            self.connection.wait(copy_to(self._pgconn, data))
+        else:
+            # Copy a buffer too large in chunks to avoid causing a memory
+            # error in the libpq, which may cause an infinite loop (#255).
+            for i in range(0, len(data), MAX_BUFFER_SIZE):
+                self.connection.wait(
+                    copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
+                )
+
+
+class QueueWriter(ConnectionWriter):
+    """
+    A writer using a buffer to queue data to write.
+
+    `write()` returns immediately, so that the main thread can be CPU-bound
+    formatting messages, while a worker thread can be IO-bound waiting to write
+    on the connection.
+    """
+
+    def __init__(self, connection: "Connection[Any]"):
+        super().__init__(connection)
+
+        self._queue: queue.Queue[bytes] = queue.Queue(maxsize=QUEUE_SIZE)
+        self._worker: Optional[threading.Thread] = None
+        self._worker_error: Optional[BaseException] = None
+
+    def worker(self) -> None:
+        """Push data to the server when available from the copy queue.
+
+        Terminate reading when the queue receives a false-y value, or in case
+        of error.
+
+        The function is designed to be run in a separate thread.
+        """
+        try:
+            while True:
+                data = self._queue.get(block=True, timeout=24 * 60 * 60)
+                if not data:
+                    break
+                self.connection.wait(copy_to(self._pgconn, data))
+        except BaseException as ex:
+            # Propagate the error to the main thread.
+            self._worker_error = ex
+
+    def write(self, data: Buffer) -> None:
+        if not self._worker:
+            # warning: reference loop, broken by finish()
+            self._worker = threading.Thread(target=self.worker)
+            self._worker.daemon = True
+            self._worker.start()
+
+        # If the worker thread raises an exception, re-raise it to the caller.
+        if self._worker_error:
+            raise self._worker_error
+
+        if len(data) <= MAX_BUFFER_SIZE:
+            # Most used path: we don't need to split the buffer in smaller
+            # bits, so don't make a copy.
+ self._queue.put(data) + else: + # Copy a buffer too large in chunks to avoid causing a memory + # error in the libpq, which may cause an infinite loop (#255). + for i in range(0, len(data), MAX_BUFFER_SIZE): + self._queue.put(data[i : i + MAX_BUFFER_SIZE]) + + def finish(self) -> None: + self._queue.put(b"") + + if self._worker: + self._worker.join() + self._worker = None # break the loop + + # Check if the worker thread raised any exception before terminating. + if self._worker_error: + raise self._worker_error + + class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): """Manage an asynchronous :sql:`COPY` operation.""" __module__ = "psycopg" - def __init__(self, cursor: "AsyncCursor[Any]"): + writer: "AsyncWriter" + + def __init__( + self, cursor: "AsyncCursor[Any]", *, writer: Optional["AsyncWriter"] = None + ): super().__init__(cursor) - self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=QUEUE_SIZE) - self._worker: Optional[asyncio.Future[None]] = None + + if not writer: + writer = AsyncQueueWriter(cursor.connection) + + self.writer = writer + self._write = writer.write async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self: self._enter() @@ -356,90 +465,54 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): async def finish(self, exc: Optional[BaseException]) -> None: if self._pgresult.status == COPY_IN: - await self._write_end() + data = self.formatter.end() + if data: + await self._write(data) + await self.writer.finish() await self.connection.wait(self._end_copy_in_gen(exc)) else: await self.connection.wait(self._end_copy_out_gen(exc)) - # Concurrent copy support - async def worker(self) -> None: - """Push data to the server when available from the copy queue. - - Terminate reading when the queue receives a false-y value. - - The function is designed to be run in a separate thread. 
- """ - while True: - data = await self._queue.get() - if not data: - break - await self.connection.wait(copy_to(self._pgconn, data)) - - async def _write(self, data: Buffer) -> None: - if not self._worker: - self._worker = create_task(self.worker()) - - if len(data) <= MAX_BUFFER_SIZE: - # Most used path: we don't need to split the buffer in smaller - # bits, so don't make a copy. - await self._queue.put(data) - else: - # Copy a buffer too large in chunks to avoid causing a memory - # error in the libpq, which may cause an infinite loop (#255). - for i in range(0, len(data), MAX_BUFFER_SIZE): - await self._queue.put(data[i : i + MAX_BUFFER_SIZE]) - - async def _write_end(self) -> None: - data = self.formatter.end() - if data: - await self._write(data) - await self._queue.put(b"") - - if self._worker: - await asyncio.gather(self._worker) - self._worker = None # break reference loops if any - - -class Writer(Generic[ConnectionType], ABC): +class AsyncWriter(ABC): """ A class to write copy data somewhere. """ @abstractmethod - def write(self, data: Buffer) -> None: + async def write(self, data: Buffer) -> None: """ Write some data to destination. """ ... - def finish(self) -> None: + async def finish(self) -> None: """ Called when write operations are finished. """ pass -class ConnectionWriter(Writer["Connection[Any]"]): - def __init__(self, connection: "Connection[Any]"): +class AsyncConnectionWriter(AsyncWriter): + def __init__(self, connection: "AsyncConnection[Any]"): self.connection = connection self._pgconn = self.connection.pgconn - def write(self, data: Buffer) -> None: + async def write(self, data: Buffer) -> None: if len(data) <= MAX_BUFFER_SIZE: # Most used path: we don't need to split the buffer in smaller # bits, so don't make a copy. 
- self.connection.wait(copy_to(self._pgconn, data)) + await self.connection.wait(copy_to(self._pgconn, data)) else: # Copy a buffer too large in chunks to avoid causing a memory # error in the libpq, which may cause an infinite loop (#255). for i in range(0, len(data), MAX_BUFFER_SIZE): - self.connection.wait( + await self.connection.wait( copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE]) ) -class QueueWriter(ConnectionWriter): +class AsyncQueueWriter(AsyncConnectionWriter): """ A writer using a buffer to queue data to write. @@ -448,62 +521,45 @@ class QueueWriter(ConnectionWriter): on the connection. """ - def __init__(self, connection: "Connection[Any]"): + def __init__(self, connection: "AsyncConnection[Any]"): super().__init__(connection) - self._queue: queue.Queue[bytes] = queue.Queue(maxsize=QUEUE_SIZE) - self._worker: Optional[threading.Thread] = None - self._worker_error: Optional[BaseException] = None + self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=QUEUE_SIZE) + self._worker: Optional[asyncio.Future[None]] = None - def worker(self) -> None: + async def worker(self) -> None: """Push data to the server when available from the copy queue. - Terminate reading when the queue receives a false-y value, or in case - of error. + Terminate reading when the queue receives a false-y value. - The function is designed to be run in a separate thread. + The function is designed to be run in a separate task. """ - try: - while True: - data = self._queue.get(block=True, timeout=24 * 60 * 60) - if not data: - break - self.connection.wait(copy_to(self._pgconn, data)) - except BaseException as ex: - # Propagate the error to the main thread. 
- self._worker_error = ex + while True: + data = await self._queue.get() + if not data: + break + await self.connection.wait(copy_to(self._pgconn, data)) - def write(self, data: Buffer) -> None: + async def write(self, data: Buffer) -> None: if not self._worker: - # warning: reference loop, broken by _write_end - self._worker = threading.Thread(target=self.worker) - self._worker.daemon = True - self._worker.start() - - # If the worker thread raies an exception, re-raise it to the caller. - if self._worker_error: - raise self._worker_error + self._worker = create_task(self.worker()) if len(data) <= MAX_BUFFER_SIZE: # Most used path: we don't need to split the buffer in smaller # bits, so don't make a copy. - self._queue.put(data) + await self._queue.put(data) else: # Copy a buffer too large in chunks to avoid causing a memory # error in the libpq, which may cause an infinite loop (#255). for i in range(0, len(data), MAX_BUFFER_SIZE): - self._queue.put(data[i : i + MAX_BUFFER_SIZE]) + await self._queue.put(data[i : i + MAX_BUFFER_SIZE]) - def finish(self) -> None: - self._queue.put(b"") + async def finish(self) -> None: + await self._queue.put(b"") if self._worker: - self._worker.join() - self._worker = None # break the loop - - # Check if the worker thread raised any exception before terminating. - if self._worker_error: - raise self._worker_error + await asyncio.gather(self._worker) + self._worker = None # break reference loops if any class Formatter(ABC): diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py index bab46c2d9..77ad80788 100644 --- a/psycopg/psycopg/cursor.py +++ b/psycopg/psycopg/cursor.py @@ -878,7 +878,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]): statement: Query, params: Optional[Params] = None, *, - writer: Optional[CopyWriter[Any]] = None, + writer: Optional[CopyWriter] = None, ) -> Iterator[Copy]: """ Initiate a :sql:`COPY` operation and return an object to manage it. 
diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py index 7f71efaad..4b4844d40 100644 --- a/psycopg/psycopg/cursor_async.py +++ b/psycopg/psycopg/cursor_async.py @@ -12,7 +12,7 @@ from contextlib import asynccontextmanager from . import pq from . import errors as e from .abc import Query, Params -from .copy import AsyncCopy +from .copy import AsyncCopy, AsyncWriter as AsyncCopyWriter from .rows import Row, RowMaker, AsyncRowFactory from .cursor import BaseCursor from ._pipeline import Pipeline @@ -204,7 +204,11 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): @asynccontextmanager async def copy( - self, statement: Query, params: Optional[Params] = None + self, + statement: Query, + params: Optional[Params] = None, + *, + writer: Optional[AsyncCopyWriter] = None, ) -> AsyncIterator[AsyncCopy]: """ :rtype: AsyncCopy @@ -213,7 +217,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): async with self._conn.lock: await self._conn.wait(self._start_copy_gen(statement, params)) - async with AsyncCopy(self) as copy: + async with AsyncCopy(self, writer=writer) as copy: yield copy except e.Error as ex: raise ex.with_traceback(None) diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index 3b0008dde..045ad8514 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -596,11 +596,11 @@ async def test_worker_life(aconn, format, buffer): cur = aconn.cursor() await ensure_table(cur, sample_tabledef) async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: - assert not copy._worker + assert not copy.writer._worker await copy.write(globals()[buffer]) - assert copy._worker + assert copy.writer._worker - assert not copy._worker + assert not copy.writer._worker await cur.execute("select * from copy_in order by 1") data = await cur.fetchall() assert data == sample_records @@ -619,6 +619,26 @@ async def test_worker_error_propagated(aconn, monkeypatch): await copy.write("a,b") 
+@pytest.mark.parametrize( + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], +) +async def test_connection_writer(aconn, format, buffer): + cur = aconn.cursor() + writer = psycopg.copy.AsyncConnectionWriter(aconn) + + await ensure_table(cur, sample_tabledef) + async with cur.copy( + f"copy copy_in from stdin (format {format.name})", writer=writer + ) as copy: + assert copy.writer is writer + await copy.write(globals()[buffer]) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + @pytest.mark.slow @pytest.mark.parametrize( "fmt, set_types",