From: Daniele Varrazzo
Date: Sat, 11 Jun 2022 14:21:38 +0000 (+0200)
Subject: refactor(copy): add Writer object
X-Git-Tag: 3.1~44^2~13
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=31ae2dc123b863787f3cf174888fa30478069708;p=thirdparty%2Fpsycopg.git

refactor(copy): add Writer object

This separation of responsibilities will allow us to:

- create writers that e.g. produce a file rather than writing to a
  connection (a sketch of such a writer follows the diff below);

- create writers that write directly instead of using a queue, also
  sketched below (see https://github.com/psycopg/psycopg/discussions/111
  for an example of users unhappy about that; I still don't see their
  nervousness as justified, but I'm OK with providing them an escape
  hatch if it's easy to implement).
---

diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py
index ffa5067ef..80d121f08 100644
--- a/psycopg/psycopg/copy.py
+++ b/psycopg/psycopg/copy.py
@@ -40,6 +40,24 @@ COPY_IN = pq.ExecStatus.COPY_IN
 ACTIVE = pq.TransactionStatus.ACTIVE
 
 
+# Size of data to accumulate before sending it down the network. We fill a
+# buffer this size field by field, and when it passes the threshold size
+# we ship it, so it may end up being bigger than this.
+BUFFER_SIZE = 32 * 1024
+
+# Maximum data size we want to queue to send to the libpq copy. Sending a
+# buffer too big to be handled can cause an infinite loop in the libpq
+# (#255), so we want to split it into more digestible chunks.
+MAX_BUFFER_SIZE = 4 * BUFFER_SIZE
+# Note: making this buffer too large, e.g.
+# MAX_BUFFER_SIZE = 1024 * 1024
+# makes operations *way* slower! Probably triggering some quadraticity
+# in the libpq memory management and data sending.
+
+# Max size of the write queue of buffers. More than that, the copy will
+# block. Each buffer should be around BUFFER_SIZE in size.
+QUEUE_SIZE = 1024
+
 
 class BaseCopy(Generic[ConnectionType]):
     """
@@ -58,25 +76,8 @@ class BaseCopy(Generic[ConnectionType]):
 
     _Self = TypeVar("_Self", bound="BaseCopy[ConnectionType]")
 
-    # Max size of the write queue of buffers. More than that copy will block
-    # Each buffer should be around BUFFER_SIZE size.
-    QUEUE_SIZE = 1024
-
-    # Size of data to accumulate before sending it down the network. We fill a
-    # buffer this size field by field, and when it passes the threshold size
-    # we ship it, so it may end up being bigger than this.
-    BUFFER_SIZE = 32 * 1024
-
-    # Maximum data size we want to queue to send to the libpq copy. Sending a
-    # buffer too big to be handled can cause an infinite loop in the libpq
-    # (#255) so we want to split it in more digestable chunks.
-    MAX_BUFFER_SIZE = 4 * BUFFER_SIZE
-    # Note: making this buffer too large, e.g.
-    # MAX_BUFFER_SIZE = 1024 * 1024
-    # makes operations *way* slower! Probably triggering some quadraticity
-    # in the libpq memory management and data sending.
-
     formatter: "Formatter"
+    writer: "Writer[ConnectionType]"
 
     def __init__(self, cursor: "BaseCursor[ConnectionType, Any]"):
         self.cursor = cursor
@@ -200,11 +201,16 @@ class Copy(BaseCopy["Connection[Any]"]):
 
     __module__ = "psycopg"
 
-    def __init__(self, cursor: "Cursor[Any]"):
+    def __init__(
+        self, cursor: "Cursor[Any]", writer: Optional["Writer[Connection[Any]]"] = None
+    ):
         super().__init__(cursor)
-        self._queue: queue.Queue[bytes] = queue.Queue(maxsize=self.QUEUE_SIZE)
-        self._worker: Optional[threading.Thread] = None
-        self._worker_error: Optional[BaseException] = None
+        if not writer:
+            writer = QueueWriter(cursor)
+
+        self.writer = writer
+        self._write = writer.write
+        self._write_end = writer.write_end
 
     def __enter__(self: BaseCopy._Self) -> BaseCopy._Self:
         self._enter()
@@ -285,66 +291,12 @@ class Copy(BaseCopy["Connection[Any]"]):
         using the `Copy` object outside a block.
         """
         if self._pgresult.status == COPY_IN:
-            self._write_end()
+            data = self.formatter.end()
+            self._write_end(data)
             self.connection.wait(self._end_copy_in_gen(exc))
         else:
             self.connection.wait(self._end_copy_out_gen(exc))
 
-    # Concurrent copy support
-
-    def worker(self) -> None:
-        """Push data to the server when available from the copy queue.
-
-        Terminate reading when the queue receives a false-y value, or in case
-        of error.
-
-        The function is designed to be run in a separate thread.
-        """
-        try:
-            while True:
-                data = self._queue.get(block=True, timeout=24 * 60 * 60)
-                if not data:
-                    break
-                self.connection.wait(copy_to(self._pgconn, data))
-        except BaseException as ex:
-            # Propagate the error to the main thread.
-            self._worker_error = ex
-
-    def _write(self, data: Buffer) -> None:
-        if not self._worker:
-            # warning: reference loop, broken by _write_end
-            self._worker = threading.Thread(target=self.worker)
-            self._worker.daemon = True
-            self._worker.start()
-
-        # If the worker thread raies an exception, re-raise it to the caller.
-        if self._worker_error:
-            raise self._worker_error
-
-        if len(data) <= self.MAX_BUFFER_SIZE:
-            # Most used path: we don't need to split the buffer in smaller
-            # bits, so don't make a copy.
-            self._queue.put(data)
-        else:
-            # Copy a buffer too large in chunks to avoid causing a memory
-            # error in the libpq, which may cause an infinite loop (#255).
-            for i in range(0, len(data), self.MAX_BUFFER_SIZE):
-                self._queue.put(data[i : i + self.MAX_BUFFER_SIZE])
-
-    def _write_end(self) -> None:
-        data = self.formatter.end()
-        if data:
-            self._write(data)
-        self._queue.put(b"")
-
-        if self._worker:
-            self._worker.join()
-            self._worker = None  # break the loop
-
-        # Check if the worker thread raised any exception before terminating.
-        if self._worker_error:
-            raise self._worker_error
-
 
 class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
     """Manage an asynchronous :sql:`COPY` operation."""
@@ -353,7 +305,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
 
     def __init__(self, cursor: "AsyncCursor[Any]"):
         super().__init__(cursor)
-        self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=self.QUEUE_SIZE)
+        self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=QUEUE_SIZE)
         self._worker: Optional[asyncio.Future[None]] = None
 
     async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self:
@@ -424,15 +376,15 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
         if not self._worker:
            self._worker = create_task(self.worker())
 
-        if len(data) <= self.MAX_BUFFER_SIZE:
+        if len(data) <= MAX_BUFFER_SIZE:
             # Most used path: we don't need to split the buffer in smaller
             # bits, so don't make a copy.
             await self._queue.put(data)
         else:
             # Copy a buffer too large in chunks to avoid causing a memory
             # error in the libpq, which may cause an infinite loop (#255).
-            for i in range(0, len(data), self.MAX_BUFFER_SIZE):
-                await self._queue.put(data[i : i + self.MAX_BUFFER_SIZE])
+            for i in range(0, len(data), MAX_BUFFER_SIZE):
+                await self._queue.put(data[i : i + MAX_BUFFER_SIZE])
 
     async def _write_end(self) -> None:
         data = self.formatter.end()
@@ -445,13 +397,101 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
             self._worker = None  # break reference loops if any
 
 
+class Writer(Generic[ConnectionType], ABC):
+    """
+    A class to write copy data somewhere.
+    """
+
+    @abstractmethod
+    def write(self, data: Buffer) -> None:
+        ...
+
+    @abstractmethod
+    def write_end(self, data: Buffer) -> None:
+        ...
+
+
+class ConnectionWriter(Writer[ConnectionType]):
+    def __init__(self, cursor: "BaseCursor[ConnectionType, Any]"):
+        self.connection = cursor.connection
+        self._pgconn = self.connection.pgconn
+
+
+class QueueWriter(ConnectionWriter["Connection[Any]"]):
+    """
+    A writer using a buffer to queue data to write.
+
+    `write()` returns immediately, so that the main thread can be CPU-bound
+    formatting messages, while a worker thread can be IO-bound waiting to
+    write on the connection.
+    """
+
+    def __init__(self, cursor: "Cursor[Any]"):
+        super().__init__(cursor)
+
+        self._queue: queue.Queue[bytes] = queue.Queue(maxsize=QUEUE_SIZE)
+        self._worker: Optional[threading.Thread] = None
+        self._worker_error: Optional[BaseException] = None
+
+    def worker(self) -> None:
+        """Push data to the server when available from the copy queue.
+
+        Terminate reading when the queue receives a false-y value, or in case
+        of error.
+
+        The function is designed to be run in a separate thread.
+        """
+        try:
+            while True:
+                data = self._queue.get(block=True, timeout=24 * 60 * 60)
+                if not data:
+                    break
+                self.connection.wait(copy_to(self._pgconn, data))
+        except BaseException as ex:
+            # Propagate the error to the main thread.
+            self._worker_error = ex
+
+    def write(self, data: Buffer) -> None:
+        if not self._worker:
+            # warning: reference loop, broken by _write_end
+            self._worker = threading.Thread(target=self.worker)
+            self._worker.daemon = True
+            self._worker.start()
+
+        # If the worker thread raised an exception, re-raise it to the caller.
+        if self._worker_error:
+            raise self._worker_error
+
+        if len(data) <= MAX_BUFFER_SIZE:
+            # Most used path: we don't need to split the buffer in smaller
+            # bits, so don't make a copy.
+            self._queue.put(data)
+        else:
+            # Copy a buffer too large in chunks to avoid causing a memory
+            # error in the libpq, which may cause an infinite loop (#255).
+            for i in range(0, len(data), MAX_BUFFER_SIZE):
+                self._queue.put(data[i : i + MAX_BUFFER_SIZE])
+
+    def write_end(self, data: Buffer) -> None:
+        if data:
+            self.write(data)
+        self._queue.put(b"")
+
+        if self._worker:
+            self._worker.join()
+            self._worker = None  # break the loop
+
+        # Check if the worker thread raised any exception before terminating.
+        if self._worker_error:
+            raise self._worker_error
+
+
 class Formatter(ABC):
     """
     A class which understand a copy format (text, binary).
""" format: pq.Format - BUFFER_SIZE = BaseCopy.BUFFER_SIZE def __init__(self, transformer: Transformer): self.transformer = transformer @@ -500,7 +540,7 @@ class TextFormatter(Formatter): self._row_mode = True format_row_text(row, self.transformer, self._write_buffer) - if len(self._write_buffer) > self.BUFFER_SIZE: + if len(self._write_buffer) > BUFFER_SIZE: buffer, self._write_buffer = self._write_buffer, bytearray() return buffer else: @@ -557,7 +597,7 @@ class BinaryFormatter(Formatter): self._signature_sent = True format_row_binary(row, self.transformer, self._write_buffer) - if len(self._write_buffer) > self.BUFFER_SIZE: + if len(self._write_buffer) > BUFFER_SIZE: buffer, self._write_buffer = self._write_buffer, bytearray() return buffer else: diff --git a/tests/test_copy.py b/tests/test_copy.py index c02483a3b..e23c1baf4 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -592,11 +592,11 @@ def test_worker_life(conn, format, buffer): cur = conn.cursor() ensure_table(cur, sample_tabledef) with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: - assert not copy._worker + assert not copy.writer._worker copy.write(globals()[buffer]) - assert copy._worker + assert copy.writer._worker - assert not copy._worker + assert not copy.writer._worker data = cur.execute("select * from copy_in order by 1").fetchall() assert data == sample_records