From: Daniele Varrazzo
Date: Fri, 15 Jan 2021 13:55:29 +0000 (+0100)
Subject: Added parallel worker to COPY FROM
X-Git-Tag: 3.0.dev0~155
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=900c85905442a787b48e7db653d9c8e9d0c986cc;p=thirdparty%2Fpsycopg.git

Added parallel worker to COPY FROM

Have a background writer push data down the connection while the main
thread is CPU-bound encoding the data.

Use a thread to implement the sync worker, and a task to implement the
asyncio one.
---

diff --git a/psycopg3/psycopg3/copy.py b/psycopg3/psycopg3/copy.py
index 20c8beaa4..ac0850f7e 100644
--- a/psycopg3/psycopg3/copy.py
+++ b/psycopg3/psycopg3/copy.py
@@ -5,7 +5,10 @@ psycopg3 copy support
 # Copyright (C) 2020 The Psycopg Team
 
 import re
+import queue
 import struct
+import asyncio
+import threading
 from types import TracebackType
 from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union
 from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple
@@ -19,7 +22,7 @@ from .generators import copy_from, copy_to, copy_end
 
 if TYPE_CHECKING:
     from .pq.proto import PGresult
-    from .cursor import BaseCursor  # noqa: F401
+    from .cursor import BaseCursor, Cursor, AsyncCursor
     from .connection import Connection, AsyncConnection  # noqa: F401
 
 TEXT = pq.Format.TEXT
@@ -27,6 +30,26 @@ BINARY = pq.Format.BINARY
 
 
 class BaseCopy(Generic[ConnectionType]):
+    """
+    Base implementation for the copy user interface.
+
+    Two subclasses expose real methods with the sync/async differences.
+
+    The difference between the text and binary format is managed by two
+    different `Formatter` subclasses.
+
+    While the interface doesn't dictate it, both subclasses are implemented
+    with a worker that performs the I/O-related work: it consumes the data,
+    already in copy format, from a queue, while the main thread formats the
+    data and adds it to the queue.
+    """
+
+    # Size of data to accumulate before sending it down the network
+    BUFFER_SIZE = 32 * 1024
+
+    # Max number of buffers in the write queue; when it is full the copy blocks
+    QUEUE_SIZE = 1024
+
     def __init__(self, cursor: "BaseCursor[ConnectionType]"):
         self.cursor = cursor
         self.connection = cursor.connection
@@ -43,7 +66,6 @@ class BaseCopy(Generic[ConnectionType]):
         self._signature_sent = False
         self._row_mode = False  # true if the user is using send_row()
         self._write_buffer = bytearray()
-        self._write_buffer_size = 32 * 1024
         self._finished = False
 
         if self.format == TEXT:
@@ -104,13 +126,12 @@ class BaseCopy(Generic[ConnectionType]):
 
         return self._parse_row(data, self.transformer)
 
-    def _write_gen(self, buffer: Union[str, bytes]) -> PQGen[None]:
-        # if write() was called, assume the header was sent together with the
-        # first block of data.
+    def _format_write(self, buffer: Union[str, bytes]) -> bytes:
+        data = self._ensure_bytes(buffer)
         self._signature_sent = True
-        yield from copy_to(self._pgconn, self._ensure_bytes(buffer))
+        return data
 
-    def _write_row_gen(self, row: Sequence[Any]) -> PQGen[None]:
+    def _format_write_row(self, row: Sequence[Any]) -> bytes:
         # Note down that we are writing in row mode: it means we will have
         # to take care of the end-of-copy marker too
         self._row_mode = True
@@ -120,57 +141,46 @@ class BaseCopy(Generic[ConnectionType]):
             self._signature_sent = True
 
         self._format_row(row, self.transformer, self._write_buffer)
 
-        if len(self._write_buffer) > self._write_buffer_size:
-            yield from copy_to(self._pgconn, self._write_buffer)
-            self._write_buffer.clear()
-
-    def _finish_gen(self, error: str = "") -> PQGen[None]:
-        if error:
-            berr = error.encode(self.connection.client_encoding, "replace")
-            res = yield from copy_end(self._pgconn, berr)
+        if len(self._write_buffer) > self.BUFFER_SIZE:
+            buffer, self._write_buffer = self._write_buffer, bytearray()
+            return buffer
+        else:
+            return b""
+
+    def _format_end(self) -> bytes:
+        if self.format == BINARY:
+            # If we have sent no data we need to send the signature
+            # and the trailer
+            if not self._signature_sent:
+                self._write_buffer += _binary_signature
+            self._write_buffer += _binary_trailer
+        elif self._row_mode:
+
+            # if we have sent data already, we have sent the signature
+            # too (either with the first row, or we assume that in
+            # block mode the signature is included).
+            # Write the trailer only if we are sending rows (with the
+            # assumption that who is copying binary data is sending the
+            # whole format).
+            self._write_buffer += _binary_trailer
+
+        buffer, self._write_buffer = self._write_buffer, bytearray()
+        return buffer
+
+    def _end_copy_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
+        bmsg: Optional[bytes]
+        if exc:
+            msg = f"error from Python: {type(exc).__qualname__} - {exc}"
+            bmsg = msg.encode(self.connection.client_encoding, "replace")
         else:
-            if self.format == BINARY:
-                # If we have sent no data we need to send the signature
-                # and the trailer
-                if not self._signature_sent:
-                    self._write_buffer += _binary_signature
-                self._write_buffer += _binary_trailer
-            elif self._row_mode:
-
-                # if we have sent data already, we have sent the signature
-                # too (either with the first row, or we assume that in
-                # block mode the signature is included).
-                # Write the trailer only if we are sending rows (with the
-                # assumption that who is copying binary data is sending the
-                # whole format).
-                self._write_buffer += _binary_trailer
-
-            if self._write_buffer:
-                yield from copy_to(self._pgconn, self._write_buffer)
-                self._write_buffer.clear()
-            res = yield from copy_end(self._pgconn, None)
+            bmsg = None
+
+        res = yield from copy_end(self._pgconn, bmsg)
         nrows = res.command_tuples
         self.cursor._rowcount = nrows if nrows is not None else -1
         self._finished = True
 
-    def _exit_gen(
-        self,
-        exc_type: Optional[Type[BaseException]],
-        exc_val: Optional[BaseException],
-    ) -> PQGen[None]:
-        # no-op in COPY TO
-        if self._pgresult.status == ExecStatus.COPY_OUT:
-            return
-
-        # In case of error in Python let's quit it here
-        if exc_type:
-            yield from self._finish_gen(
-                f"error from Python: {exc_type.__qualname__} - {exc_val}"
-            )
-        else:
-            yield from self._finish_gen()
-
     # Support methods
 
     def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
@@ -197,6 +207,35 @@ class Copy(BaseCopy["Connection"]):
 
     __module__ = "psycopg3"
 
+    def __init__(self, cursor: "Cursor"):
+        super().__init__(cursor)
+        self._queue: queue.Queue[Optional[bytes]] = queue.Queue(
+            maxsize=self.QUEUE_SIZE
+        )
+        self._worker: Optional[threading.Thread] = None
+
+    def __enter__(self) -> "Copy":
+        self._check_reuse()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        self.finish(exc_val)
+
+    # End user sync interface
+
+    def __iter__(self) -> Iterator[memoryview]:
+        """Implement block-by-block iteration on :sql:`COPY TO`."""
+        while True:
+            data = self.read()
+            if not data:
+                break
+            yield data
+
     def read(self) -> memoryview:
         """
         Read an unparsed row after a :sql:`COPY TO` operation.
@@ -236,41 +275,96 @@ class Copy(BaseCopy["Connection"]):
         If the :sql:`COPY` is in binary format *buffer* must be `!bytes`. In
         text mode it can be either `!bytes` or `!str`.
         """
-        self.connection.wait(self._write_gen(buffer))
+        data = self._format_write(buffer)
+        self._write(data)
 
     def write_row(self, row: Sequence[Any]) -> None:
         """Write a record to a table after a :sql:`COPY FROM` operation."""
-        self.connection.wait(self._write_row_gen(row))
+        data = self._format_write_row(row)
+        self._write(data)
 
-    def _finish(self, error: str = "") -> None:
-        """Terminate a :sql:`COPY FROM` operation."""
-        self.connection.wait(self._finish_gen(error))
+    def finish(self, exc: Optional[BaseException]) -> None:
+        """Terminate the copy operation and free the resources allocated.
 
-    def __enter__(self) -> "Copy":
+        You shouldn't need to call this function yourself: it is usually called
+        by exit. It is available if, despite what is documented, you end up
+        using the `Copy` object outside a block.
+        """
+        # no-op in COPY TO
+        if self._pgresult.status == ExecStatus.COPY_OUT:
+            return
+
+        self._write_end()
+        self.connection.wait(self._end_copy_gen(exc))
+
+    # Concurrent copy support
+
+    def worker(self) -> None:
+        """Push data to the server when available from the copy queue.
+
+        Terminate reading when the queue receives a None.
+
+        The function is designed to be run in a separate thread.
+        """
+        while 1:
+            data = self._queue.get(block=True, timeout=24 * 60 * 60)
+            if not data:
+                break
+            self.connection.wait(copy_to(self._pgconn, data))
+
+    def _write(self, data: bytes) -> None:
+        if not data:
+            return
+
+        if not self._worker:
+            # warning: reference loop, broken by _write_end
+            self._worker = threading.Thread(target=self.worker)
+            self._worker.daemon = True
+            self._worker.start()
+
+        self._queue.put(data)
+
+    def _write_end(self) -> None:
+        data = self._format_end()
+        self._write(data)
+        self._queue.put(None)
+
+        if self._worker:
+            self._worker.join()
+            self._worker = None  # break the loop
+
+
+class AsyncCopy(BaseCopy["AsyncConnection"]):
+    """Manage an asynchronous :sql:`COPY` operation."""
+
+    __module__ = "psycopg3"
+
+    def __init__(self, cursor: "AsyncCursor"):
+        super().__init__(cursor)
+        self._queue: asyncio.Queue[Optional[bytes]] = asyncio.Queue(
+            maxsize=self.QUEUE_SIZE
+        )
+        self._worker: Optional[asyncio.Future[None]] = None
+
+    async def __aenter__(self) -> "AsyncCopy":
         self._check_reuse()
         return self
 
-    def __exit__(
+    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
-        self.connection.wait(self._exit_gen(exc_type, exc_val))
+        await self.finish(exc_val)
 
-    def __iter__(self) -> Iterator[memoryview]:
+    async def __aiter__(self) -> AsyncIterator[memoryview]:
         while True:
-            data = self.read()
+            data = await self.read()
             if not data:
                 break
             yield data
 
-
-class AsyncCopy(BaseCopy["AsyncConnection"]):
-    """Manage an asynchronous :sql:`COPY` operation."""
-
-    __module__ = "psycopg3"
-
     async def read(self) -> memoryview:
         return await self.connection.wait(self._read_gen())
 
@@ -285,32 +379,54 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
         return await self.connection.wait(self._read_row_gen())
 
     async def write(self, buffer: Union[str, bytes]) -> None:
-        await self.connection.wait(self._write_gen(buffer))
+        data = self._format_write(buffer)
+        await self._write(data)
 
     async def write_row(self, row: Sequence[Any]) -> None:
-        await self.connection.wait(self._write_row_gen(row))
+        data = self._format_write_row(row)
+        await self._write(data)
 
-    async def _finish(self, error: str = "") -> None:
-        await self.connection.wait(self._finish_gen(error))
+    async def finish(self, exc: Optional[BaseException]) -> None:
+        # no-op in COPY TO
+        if self._pgresult.status == ExecStatus.COPY_OUT:
+            return
 
-    async def __aenter__(self) -> "AsyncCopy":
-        self._check_reuse()
-        return self
+        await self._write_end()
+        await self.connection.wait(self._end_copy_gen(exc))
 
-    async def __aexit__(
-        self,
-        exc_type: Optional[Type[BaseException]],
-        exc_val: Optional[BaseException],
-        exc_tb: Optional[TracebackType],
-    ) -> None:
-        await self.connection.wait(self._exit_gen(exc_type, exc_val))
+    # Concurrent copy support
 
-    async def __aiter__(self) -> AsyncIterator[memoryview]:
-        while True:
-            data = await self.read()
+    async def worker(self) -> None:
+        """Push data to the server when available from the copy queue.
+
+        Terminate reading when the queue receives a None.
+
+        The function is designed to be run as an asyncio task.
+        """
+        while 1:
+            data = await self._queue.get()
             if not data:
                 break
-            yield data
+            await self.connection.wait(copy_to(self._pgconn, data))
+
+    async def _write(self, data: bytes) -> None:
+        if not data:
+            return
+
+        if not self._worker:
+            # TODO: can be asyncio.create_task once Python 3.6 is dropped
+            self._worker = asyncio.ensure_future(self.worker())
+
+        await self._queue.put(data)
+
+    async def _write_end(self) -> None:
+        data = self._format_end()
+        await self._write(data)
+        await self._queue.put(None)
+
+        if self._worker:
+            await asyncio.gather(self._worker)
+            self._worker = None  # break reference loops if any
 
 
 def _format_row_text(
diff --git a/tests/test_copy.py b/tests/test_copy.py
index 63e4242bc..ea12f1e00 100644
--- a/tests/test_copy.py
+++ b/tests/test_copy.py
@@ -479,6 +479,23 @@ def test_str(conn):
     assert "[INTRANS]" in str(copy)
 
 
+@pytest.mark.parametrize(
+    "format, buffer",
+    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+def test_worker_life(conn, format, buffer):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+    with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+        assert not copy._worker
+        copy.write(globals()[buffer])
+        assert copy._worker
+
+    assert not copy._worker
+    data = cur.execute("select * from copy_in order by 1").fetchall()
+    assert data == sample_records
+
+
 def py_to_raw(item, fmt):
     """Convert from Python type to the expected result from the db"""
     if fmt == Format.TEXT:
diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py
index 75855743c..6dced8dd9 100644
--- a/tests/test_copy_async.py
+++ b/tests/test_copy_async.py
@@ -452,6 +452,26 @@ async def test_str(aconn):
     assert "[INTRANS]" in str(copy)
 
 
+@pytest.mark.parametrize(
+    "format, buffer",
+    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+async def test_worker_life(aconn, format, buffer):
+    cur = await aconn.cursor()
+    await ensure_table(cur, sample_tabledef)
+    async with cur.copy(
+        f"copy copy_in from stdin (format {format.name})"
+    ) as copy:
+        assert not copy._worker
+        await copy.write(globals()[buffer])
+        assert copy._worker
+
+    assert not copy._worker
+    await cur.execute("select * from copy_in order by 1")
+    data = await cur.fetchall()
+    assert data == sample_records
+
+
 async def ensure_table(cur, tabledef, name="copy_in"):
     await cur.execute(f"drop table if exists {name}")
     await cur.execute(f"create table {name} ({tabledef})")
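
For context, this is roughly how the interface touched by the patch is used from
application code, mirroring the tests added above. The connection string and the
table definition are made up for the example; the comments describe what the
patch makes happen behind the scenes.

    # Illustrative sketch only: DSN and table are placeholders.
    import psycopg3

    conn = psycopg3.connect("dbname=test")
    cur = conn.cursor()
    cur.execute("create table if not exists copy_in (data text)")

    # Entering the block prepares the copy; leaving it calls finish(), which
    # flushes the formatted data, joins the worker, and sends the end-of-copy.
    with cur.copy("copy copy_in from stdin") as copy:
        for i in range(1000):
            # Formatted on the main thread, sent to the server by the worker.
            copy.write_row((f"row {i}",))

    conn.commit()

The worker only starts on the first non-empty write, so a copy that writes
nothing never spawns a thread; this is what the test_worker_life tests above
check before, during and after the block.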
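The concurrency added to the sync Copy is a bounded producer/consumer: the main
thread formats rows and puts buffers on a queue.Queue(maxsize=QUEUE_SIZE), a
lazily started daemon thread drains it and writes to the connection, and a None
sentinel plus join() shuts it down. Below is a self-contained sketch of the same
shape outside psycopg3; send() and the class name are made up, standing in for
copy_to() on the real connection.

    import queue
    import threading
    from typing import Optional

    QUEUE_SIZE = 1024  # same role as BaseCopy.QUEUE_SIZE: bound the producer


    def send(data: bytes) -> None:
        """Stand-in for pushing a buffer down the network connection."""


    class Writer:
        def __init__(self) -> None:
            self._queue: "queue.Queue[Optional[bytes]]" = queue.Queue(
                maxsize=QUEUE_SIZE
            )
            self._worker: Optional[threading.Thread] = None

        def worker(self) -> None:
            # Consume buffers until the None sentinel arrives.
            while True:
                data = self._queue.get()
                if data is None:
                    break
                send(data)

        def write(self, data: bytes) -> None:
            if not data:
                return
            if self._worker is None:
                # Started lazily on the first non-empty write, as in the patch.
                self._worker = threading.Thread(target=self.worker, daemon=True)
                self._worker.start()
            # Blocks here if the consumer falls QUEUE_SIZE buffers behind.
            self._queue.put(data)

        def close(self) -> None:
            self._queue.put(None)  # sentinel: tell the worker to stop
            if self._worker is not None:
                self._worker.join()
                self._worker = None


    if __name__ == "__main__":
        w = Writer()
        for _ in range(10):
            w.write(b"some data")
        w.close()

The sketch uses a plain blocking get(); the patch passes a day-long timeout to
Queue.get() instead, presumably so the wait in the worker thread never blocks
without any deadline at all.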
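The asyncio side has the same shape, with asyncio.Queue and a task in place of
the thread (the patch uses asyncio.ensure_future() only to keep Python 3.6
support, per its TODO note). A comparable sketch, again with made-up names and a
dummy send() coroutine, assuming Python 3.7+ where asyncio.run() and
asyncio.create_task() are available:

    import asyncio
    from typing import Optional

    QUEUE_SIZE = 1024


    async def send(data: bytes) -> None:
        """Stand-in for awaiting the data being written to the connection."""
        await asyncio.sleep(0)


    class AsyncWriter:
        def __init__(self) -> None:
            self._queue: "asyncio.Queue[Optional[bytes]]" = asyncio.Queue(
                maxsize=QUEUE_SIZE
            )
            self._worker: Optional["asyncio.Task[None]"] = None

        async def worker(self) -> None:
            # Consume buffers until the None sentinel arrives.
            while True:
                data = await self._queue.get()
                if data is None:
                    break
                await send(data)

        async def write(self, data: bytes) -> None:
            if not data:
                return
            if self._worker is None:
                self._worker = asyncio.create_task(self.worker())
            # Backpressure: suspends here when the queue is full.
            await self._queue.put(data)

        async def close(self) -> None:
            await self._queue.put(None)  # sentinel: tell the worker to stop
            if self._worker is not None:
                await self._worker
                self._worker = None


    async def main() -> None:
        w = AsyncWriter()
        for _ in range(10):
            await w.write(b"some data")
        await w.close()


    if __name__ == "__main__":
        asyncio.run(main())

Awaiting the queue's put() gives the async copy the same backpressure the sync
version gets from the blocking Queue.put(): an application that formats rows
faster than the network can take them is paused instead of buffering without
bound.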