refactor(copy): add Writer object
author    Daniele Varrazzo <daniele.varrazzo@gmail.com>  Sat, 11 Jun 2022 14:21:38 +0000 (16:21 +0200)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>  Tue, 19 Jul 2022 14:09:52 +0000 (15:09 +0100)
This separation of responsibilities will make it possible to:

- create writers that e.g. produce a file rather than writing to a
connection (a sketch of such a writer follows this list);

- create writers that write directly instead of using a queue (see
https://github.com/psycopg/psycopg/discussions/111 for an example of users
unhappy about the queue; I still don't think their concern is justified,
but I'm happy to provide an escape hatch if it's easy to implement; a
sketch of such a writer closes this page).
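
A sketch of the file-producing writer mentioned above. This is hypothetical:
the FileWriter name, its wiring, and the import paths are assumptions, not
part of this commit; it only subclasses the Writer ABC added in the diff below.

    # Hypothetical sketch, not part of this commit: a writer that dumps
    # COPY data to a binary file instead of sending it to a connection.
    from typing import Any, BinaryIO

    from psycopg.abc import Buffer   # bytes | bytearray | memoryview
    from psycopg.copy import Writer  # ABC introduced by this commit

    class FileWriter(Writer[Any]):
        """Write COPY data to a binary file-like object."""

        def __init__(self, file: BinaryIO):
            self.file = file

        def write(self, data: Buffer) -> None:
            # Each chunk produced by the formatter is appended verbatim.
            self.file.write(data)

        def write_end(self, data: Buffer) -> None:
            # Trailing bytes returned by formatter.end() (e.g. the binary
            # format terminator); closing the file is up to the caller.
            if data:
                self.file.write(data)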

psycopg/psycopg/copy.py
tests/test_copy.py

index ffa5067ef3c4b977af83f5a3fe8ad8274f8e5d2e..80d121f080a5ebd6f3abb9b0e31ec8abfd948f26 100644 (file)
@@ -40,6 +40,24 @@ COPY_IN = pq.ExecStatus.COPY_IN
 
 ACTIVE = pq.TransactionStatus.ACTIVE
 
+# Size of data to accumulate before sending it down the network. We fill a
+# buffer of this size field by field; once it grows past this threshold we
+# ship it, so the buffer actually sent may end up somewhat bigger.
+BUFFER_SIZE = 32 * 1024
+
+# Maximum data size we want to queue to send to the libpq copy. Sending a
+# buffer too big to be handled can cause an infinite loop in the libpq
+# (#255), so we want to split it into more digestible chunks.
+MAX_BUFFER_SIZE = 4 * BUFFER_SIZE
+# Note: making this buffer too large, e.g.
+# MAX_BUFFER_SIZE = 1024 * 1024
+# makes operations *way* slower! It probably triggers some quadratic
+# behaviour in the libpq memory management and data sending.
+
+# Max size of the write queue of buffers. Beyond this size, the copy will
+# block. Each buffer should be around BUFFER_SIZE in size.
+QUEUE_SIZE = 1024
+
 
 class BaseCopy(Generic[ConnectionType]):
     """
@@ -58,25 +76,8 @@ class BaseCopy(Generic[ConnectionType]):
 
     _Self = TypeVar("_Self", bound="BaseCopy[ConnectionType]")
 
-    # Max size of the write queue of buffers. More than that copy will block
-    # Each buffer should be around BUFFER_SIZE size.
-    QUEUE_SIZE = 1024
-
-    # Size of data to accumulate before sending it down the network. We fill a
-    # buffer this size field by field, and when it passes the threshold size
-    # we ship it, so it may end up being bigger than this.
-    BUFFER_SIZE = 32 * 1024
-
-    # Maximum data size we want to queue to send to the libpq copy. Sending a
-    # buffer too big to be handled can cause an infinite loop in the libpq
-    # (#255) so we want to split it in more digestable chunks.
-    MAX_BUFFER_SIZE = 4 * BUFFER_SIZE
-    # Note: making this buffer too large, e.g.
-    # MAX_BUFFER_SIZE = 1024 * 1024
-    # makes operations *way* slower! Probably triggering some quadraticity
-    # in the libpq memory management and data sending.
-
     formatter: "Formatter"
+    writer: "Writer[ConnectionType]"
 
     def __init__(self, cursor: "BaseCursor[ConnectionType, Any]"):
         self.cursor = cursor
@@ -200,11 +201,16 @@ class Copy(BaseCopy["Connection[Any]"]):
 
     __module__ = "psycopg"
 
-    def __init__(self, cursor: "Cursor[Any]"):
+    def __init__(
+        self, cursor: "Cursor[Any]", writer: Optional["Writer[Connection[Any]]"] = None
+    ):
         super().__init__(cursor)
-        self._queue: queue.Queue[bytes] = queue.Queue(maxsize=self.QUEUE_SIZE)
-        self._worker: Optional[threading.Thread] = None
-        self._worker_error: Optional[BaseException] = None
+        if not writer:
+            writer = QueueWriter(cursor)
+
+        self.writer = writer
+        self._write = writer.write
+        self._write_end = writer.write_end
 
     def __enter__(self: BaseCopy._Self) -> BaseCopy._Self:
         self._enter()
@@ -285,66 +291,12 @@ class Copy(BaseCopy["Connection[Any]"]):
         using the `Copy` object outside a block.
         """
         if self._pgresult.status == COPY_IN:
-            self._write_end()
+            data = self.formatter.end()
+            self._write_end(data)
             self.connection.wait(self._end_copy_in_gen(exc))
         else:
             self.connection.wait(self._end_copy_out_gen(exc))
 
-    # Concurrent copy support
-
-    def worker(self) -> None:
-        """Push data to the server when available from the copy queue.
-
-        Terminate reading when the queue receives a false-y value, or in case
-        of error.
-
-        The function is designed to be run in a separate thread.
-        """
-        try:
-            while True:
-                data = self._queue.get(block=True, timeout=24 * 60 * 60)
-                if not data:
-                    break
-                self.connection.wait(copy_to(self._pgconn, data))
-        except BaseException as ex:
-            # Propagate the error to the main thread.
-            self._worker_error = ex
-
-    def _write(self, data: Buffer) -> None:
-        if not self._worker:
-            # warning: reference loop, broken by _write_end
-            self._worker = threading.Thread(target=self.worker)
-            self._worker.daemon = True
-            self._worker.start()
-
-        # If the worker thread raies an exception, re-raise it to the caller.
-        if self._worker_error:
-            raise self._worker_error
-
-        if len(data) <= self.MAX_BUFFER_SIZE:
-            # Most used path: we don't need to split the buffer in smaller
-            # bits, so don't make a copy.
-            self._queue.put(data)
-        else:
-            # Copy a buffer too large in chunks to avoid causing a memory
-            # error in the libpq, which may cause an infinite loop (#255).
-            for i in range(0, len(data), self.MAX_BUFFER_SIZE):
-                self._queue.put(data[i : i + self.MAX_BUFFER_SIZE])
-
-    def _write_end(self) -> None:
-        data = self.formatter.end()
-        if data:
-            self._write(data)
-        self._queue.put(b"")
-
-        if self._worker:
-            self._worker.join()
-            self._worker = None  # break the loop
-
-        # Check if the worker thread raised any exception before terminating.
-        if self._worker_error:
-            raise self._worker_error
-
 
 class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
     """Manage an asynchronous :sql:`COPY` operation."""
@@ -353,7 +305,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
 
     def __init__(self, cursor: "AsyncCursor[Any]"):
         super().__init__(cursor)
-        self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=self.QUEUE_SIZE)
+        self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=QUEUE_SIZE)
         self._worker: Optional[asyncio.Future[None]] = None
 
     async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self:
@@ -424,15 +376,15 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
         if not self._worker:
             self._worker = create_task(self.worker())
 
-        if len(data) <= self.MAX_BUFFER_SIZE:
+        if len(data) <= MAX_BUFFER_SIZE:
             # Most used path: we don't need to split the buffer in smaller
             # bits, so don't make a copy.
             await self._queue.put(data)
         else:
             # Copy a buffer too large in chunks to avoid causing a memory
             # error in the libpq, which may cause an infinite loop (#255).
-            for i in range(0, len(data), self.MAX_BUFFER_SIZE):
-                await self._queue.put(data[i : i + self.MAX_BUFFER_SIZE])
+            for i in range(0, len(data), MAX_BUFFER_SIZE):
+                await self._queue.put(data[i : i + MAX_BUFFER_SIZE])
 
     async def _write_end(self) -> None:
         data = self.formatter.end()
@@ -445,13 +397,101 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
             self._worker = None  # break reference loops if any
 
 
+class Writer(Generic[ConnectionType], ABC):
+    """
+    A class to write copy data somewhere.
+    """
+
+    @abstractmethod
+    def write(self, data: Buffer) -> None:
+        ...
+
+    @abstractmethod
+    def write_end(self, data: Buffer) -> None:
+        ...
+
+
+class ConnectionWriter(Writer[ConnectionType]):
+    def __init__(self, cursor: "BaseCursor[ConnectionType, Any]"):
+        self.connection = cursor.connection
+        self._pgconn = self.connection.pgconn
+
+
+class QueueWriter(ConnectionWriter["Connection[Any]"]):
+    """
+    A writer using a queue to buffer the data to write.
+
+    `write()` returns immediately, so that the main thread can be CPU-bound
+    formatting messages, while a worker thread can be IO-bound waiting to write
+    on the connection.
+    """
+
+    def __init__(self, cursor: "Cursor[Any]"):
+        super().__init__(cursor)
+
+        self._queue: queue.Queue[bytes] = queue.Queue(maxsize=QUEUE_SIZE)
+        self._worker: Optional[threading.Thread] = None
+        self._worker_error: Optional[BaseException] = None
+
+    def worker(self) -> None:
+        """Push data to the server when available from the copy queue.
+
+        Terminate reading when the queue receives a false-y value, or in case
+        of error.
+
+        The function is designed to be run in a separate thread.
+        """
+        try:
+            while True:
+                data = self._queue.get(block=True, timeout=24 * 60 * 60)
+                if not data:
+                    break
+                self.connection.wait(copy_to(self._pgconn, data))
+        except BaseException as ex:
+            # Propagate the error to the main thread.
+            self._worker_error = ex
+
+    def write(self, data: Buffer) -> None:
+        if not self._worker:
+            # warning: reference loop, broken by write_end()
+            self._worker = threading.Thread(target=self.worker)
+            self._worker.daemon = True
+            self._worker.start()
+
+        # If the worker thread raises an exception, re-raise it to the caller.
+        if self._worker_error:
+            raise self._worker_error
+
+        if len(data) <= MAX_BUFFER_SIZE:
+            # Most used path: we don't need to split the buffer into smaller
+            # bits, so don't make a copy.
+            self._queue.put(data)
+        else:
+            # Copy an oversized buffer to the queue in chunks, to avoid a
+            # memory error in the libpq, which may cause an infinite loop (#255).
+            for i in range(0, len(data), MAX_BUFFER_SIZE):
+                self._queue.put(data[i : i + MAX_BUFFER_SIZE])
+
+    def write_end(self, data: Buffer) -> None:
+        if data:
+            self.write(data)
+        self._queue.put(b"")
+
+        if self._worker:
+            self._worker.join()
+            self._worker = None  # break the loop
+
+        # Check if the worker thread raised any exception before terminating.
+        if self._worker_error:
+            raise self._worker_error
+
+
 class Formatter(ABC):
     """
     A class which understand a copy format (text, binary).
     """
 
     format: pq.Format
-    BUFFER_SIZE = BaseCopy.BUFFER_SIZE
 
     def __init__(self, transformer: Transformer):
         self.transformer = transformer
@@ -500,7 +540,7 @@ class TextFormatter(Formatter):
         self._row_mode = True
 
         format_row_text(row, self.transformer, self._write_buffer)
-        if len(self._write_buffer) > self.BUFFER_SIZE:
+        if len(self._write_buffer) > BUFFER_SIZE:
             buffer, self._write_buffer = self._write_buffer, bytearray()
             return buffer
         else:
@@ -557,7 +597,7 @@ class BinaryFormatter(Formatter):
             self._signature_sent = True
 
         format_row_binary(row, self.transformer, self._write_buffer)
-        if len(self._write_buffer) > self.BUFFER_SIZE:
+        if len(self._write_buffer) > BUFFER_SIZE:
             buffer, self._write_buffer = self._write_buffer, bytearray()
             return buffer
         else:
index c02483a3bd880e0162af72b46964079778789ae0..e23c1baf4ec09577b52977945f9f7c4fbb28d85f 100644 (file)
@@ -592,11 +592,11 @@ def test_worker_life(conn, format, buffer):
     cur = conn.cursor()
     ensure_table(cur, sample_tabledef)
     with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
-        assert not copy._worker
+        assert not copy.writer._worker
         copy.write(globals()[buffer])
-        assert copy._worker
+        assert copy.writer._worker
 
-    assert not copy._worker
+    assert not copy.writer._worker
     data = cur.execute("select * from copy_in order by 1").fetchall()
     assert data == sample_records
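
The "escape hatch" from the commit message could then be a queue-less writer
along these lines. This too is a hypothetical sketch: the DirectWriter name
and its details are assumptions; only the Writer plumbing is in this commit.
It reuses the copy_to() generator, the MAX_BUFFER_SIZE constant, and the
ConnectionWriter base class shown in the copy.py diff above.

    # Hypothetical sketch, not part of this commit: write synchronously to
    # the connection, with no queue and no worker thread, as requested in
    # discussion #111.
    class DirectWriter(ConnectionWriter["Connection[Any]"]):
        def write(self, data: Buffer) -> None:
            # Split oversized buffers as QueueWriter does (libpq #255) and
            # block until each chunk has been sent.
            for i in range(0, len(data), MAX_BUFFER_SIZE):
                chunk = data[i : i + MAX_BUFFER_SIZE]
                self.connection.wait(copy_to(self._pgconn, chunk))

        def write_end(self, data: Buffer) -> None:
            # Ship the trailing data returned by formatter.end(), if any.
            if data:
                self.write(data)

Given the Copy.__init__ signature added above, such a writer would presumably
be used as Copy(cursor, writer=DirectWriter(cursor)).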