]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added parallel worker to COPY FROM
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 15 Jan 2021 13:55:29 +0000 (14:55 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 16 Jan 2021 01:06:55 +0000 (02:06 +0100)
Have a background writer to push data down the connection while the main
thread is CPU-bound to encode data.

Use a thread to implement the sync worker, a task to implement the
asyncio one.

psycopg3/psycopg3/copy.py
tests/test_copy.py
tests/test_copy_async.py

index 20c8beaa42206f3919fc74e69df1fd3da03cde18..ac0850f7e1ca945df140b12faf9ce155a6b1e1e9 100644 (file)
@@ -5,7 +5,10 @@ psycopg3 copy support
 # Copyright (C) 2020 The Psycopg Team
 
 import re
+import queue
 import struct
+import asyncio
+import threading
 from types import TracebackType
 from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union
 from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple
@@ -19,7 +22,7 @@ from .generators import copy_from, copy_to, copy_end
 
 if TYPE_CHECKING:
     from .pq.proto import PGresult
-    from .cursor import BaseCursor  # noqa: F401
+    from .cursor import BaseCursor, Cursor, AsyncCursor
     from .connection import Connection, AsyncConnection  # noqa: F401
 
 TEXT = pq.Format.TEXT
@@ -27,6 +30,26 @@ BINARY = pq.Format.BINARY
 
 
 class BaseCopy(Generic[ConnectionType]):
+    """
+    Base implementation for copy user interface
+
+    Two subclasses expose real methods with the sync/async differences.
+
+    The difference between the text and binary format is managed by two
+    different `Formatter` subclasses.
+
+    While the interface doesn't dictate it, both subclasses are implemented
+    with a worker to perform I/O related work, consuming the data provided in
+    the correct format from a queue, while the main thread is concerned with
+    formatting the data in copy format and adding it to the queue.
+    """
+
+    # Size of data to accumulate before sending it down the network
+    BUFFER_SIZE = 32 * 1024
+
+    # Max size of the write queue of buffers. More than that copy will block
+    QUEUE_SIZE = 1024
+
     def __init__(self, cursor: "BaseCursor[ConnectionType]"):
         self.cursor = cursor
         self.connection = cursor.connection
@@ -43,7 +66,6 @@ class BaseCopy(Generic[ConnectionType]):
         self._signature_sent = False
         self._row_mode = False  # true if the user is using send_row()
         self._write_buffer = bytearray()
-        self._write_buffer_size = 32 * 1024
         self._finished = False
 
         if self.format == TEXT:
@@ -104,13 +126,12 @@ class BaseCopy(Generic[ConnectionType]):
 
         return self._parse_row(data, self.transformer)
 
-    def _write_gen(self, buffer: Union[str, bytes]) -> PQGen[None]:
-        # if write() was called, assume the header was sent together with the
-        # first block of data.
+    def _format_write(self, buffer: Union[str, bytes]) -> bytes:
+        data = self._ensure_bytes(buffer)
         self._signature_sent = True
-        yield from copy_to(self._pgconn, self._ensure_bytes(buffer))
+        return data
 
-    def _write_row_gen(self, row: Sequence[Any]) -> PQGen[None]:
+    def _format_write_row(self, row: Sequence[Any]) -> bytes:
         # Note down that we are writing in row mode: it means we will have
         # to take care of the end-of-copy marker too
         self._row_mode = True
@@ -120,57 +141,46 @@ class BaseCopy(Generic[ConnectionType]):
             self._signature_sent = True
 
         self._format_row(row, self.transformer, self._write_buffer)
-        if len(self._write_buffer) > self._write_buffer_size:
-            yield from copy_to(self._pgconn, self._write_buffer)
-            self._write_buffer.clear()
-
-    def _finish_gen(self, error: str = "") -> PQGen[None]:
-        if error:
-            berr = error.encode(self.connection.client_encoding, "replace")
-            res = yield from copy_end(self._pgconn, berr)
+        if len(self._write_buffer) > self.BUFFER_SIZE:
+            buffer, self._write_buffer = self._write_buffer, bytearray()
+            return buffer
+        else:
+            return b""
+
+    def _format_end(self) -> bytes:
+        if self.format == BINARY:
+            # If we have sent no data we need to send the signature
+            # and the trailer
+            if not self._signature_sent:
+                self._write_buffer += _binary_signature
+                self._write_buffer += _binary_trailer
+            elif self._row_mode:
+
+                # if we have sent data already, we have sent the signature
+                # too (either with the first row, or we assume that in
+                # block mode the signature is included).
+                # Write the trailer only if we are sending rows (with the
+                # assumption that who is copying binary data is sending the
+                # whole format).
+                self._write_buffer += _binary_trailer
+
+        buffer, self._write_buffer = self._write_buffer, bytearray()
+        return buffer
+
+    def _end_copy_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
+        bmsg: Optional[bytes]
+        if exc:
+            msg = f"error from Python: {type(exc).__qualname__} - {exc}"
+            bmsg = msg.encode(self.connection.client_encoding, "replace")
         else:
-            if self.format == BINARY:
-                # If we have sent no data we need to send the signature
-                # and the trailer
-                if not self._signature_sent:
-                    self._write_buffer += _binary_signature
-                    self._write_buffer += _binary_trailer
-                elif self._row_mode:
-
-                    # if we have sent data already, we have sent the signature
-                    # too (either with the first row, or we assume that in
-                    # block mode the signature is included).
-                    # Write the trailer only if we are sending rows (with the
-                    # assumption that who is copying binary data is sending the
-                    # whole format).
-                    self._write_buffer += _binary_trailer
-
-            if self._write_buffer:
-                yield from copy_to(self._pgconn, self._write_buffer)
-                self._write_buffer.clear()
-            res = yield from copy_end(self._pgconn, None)
+            bmsg = None
+
+        res = yield from copy_end(self._pgconn, bmsg)
 
         nrows = res.command_tuples
         self.cursor._rowcount = nrows if nrows is not None else -1
         self._finished = True
 
-    def _exit_gen(
-        self,
-        exc_type: Optional[Type[BaseException]],
-        exc_val: Optional[BaseException],
-    ) -> PQGen[None]:
-        # no-op in COPY TO
-        if self._pgresult.status == ExecStatus.COPY_OUT:
-            return
-
-        # In case of error in Python let's quit it here
-        if exc_type:
-            yield from self._finish_gen(
-                f"error from Python: {exc_type.__qualname__} - {exc_val}"
-            )
-        else:
-            yield from self._finish_gen()
-
     # Support methods
 
     def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
@@ -197,6 +207,35 @@ class Copy(BaseCopy["Connection"]):
 
     __module__ = "psycopg3"
 
+    def __init__(self, cursor: "Cursor"):
+        super().__init__(cursor)
+        self._queue: queue.Queue[Optional[bytes]] = queue.Queue(
+            maxsize=self.QUEUE_SIZE
+        )
+        self._worker: Optional[threading.Thread] = None
+
+    def __enter__(self) -> "Copy":
+        self._check_reuse()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        self.finish(exc_val)
+
+    # End user sync interface
+
+    def __iter__(self) -> Iterator[memoryview]:
+        """Implement block-by-block iteration on :sql:`COPY TO`."""
+        while True:
+            data = self.read()
+            if not data:
+                break
+            yield data
+
     def read(self) -> memoryview:
         """
         Read an unparsed row after a :sql:`COPY TO` operation.
@@ -236,41 +275,96 @@ class Copy(BaseCopy["Connection"]):
         If the :sql:`COPY` is in binary format *buffer* must be `!bytes`. In
         text mode it can be either `!bytes` or `!str`.
         """
-        self.connection.wait(self._write_gen(buffer))
+        data = self._format_write(buffer)
+        self._write(data)
 
     def write_row(self, row: Sequence[Any]) -> None:
         """Write a record to a table after a :sql:`COPY FROM` operation."""
-        self.connection.wait(self._write_row_gen(row))
+        data = self._format_write_row(row)
+        self._write(data)
 
-    def _finish(self, error: str = "") -> None:
-        """Terminate a :sql:`COPY FROM` operation."""
-        self.connection.wait(self._finish_gen(error))
+    def finish(self, exc: Optional[BaseException]) -> None:
+        """Terminate the copy operation and free the resources allocated.
 
-    def __enter__(self) -> "Copy":
+        You shouldn't need to call this function yourself: it is usually called
+        by exit. It is available if, despite what is documented, you end up
+        using the `Copy` object outside a block.
+        """
+        # no-op in COPY TO
+        if self._pgresult.status == ExecStatus.COPY_OUT:
+            return
+
+        self._write_end()
+        self.connection.wait(self._end_copy_gen(exc))
+
+    # Concurrent copy support
+
+    def worker(self) -> None:
+        """Push data to the server when available from the copy queue.
+
+        Terminate reading when the queue receives a None.
+
+        The function is designed to be run in a separate thread.
+        """
+        while 1:
+            data = self._queue.get(block=True, timeout=24 * 60 * 60)
+            if not data:
+                break
+            self.connection.wait(copy_to(self._pgconn, data))
+
+    def _write(self, data: bytes) -> None:
+        if not data:
+            return
+
+        if not self._worker:
+            # warning: reference loop, broken by _write_end
+            self._worker = threading.Thread(target=self.worker)
+            self._worker.daemon = True
+            self._worker.start()
+
+        self._queue.put(data)
+
+    def _write_end(self) -> None:
+        data = self._format_end()
+        self._write(data)
+        self._queue.put(None)
+
+        if self._worker:
+            self._worker.join()
+            self._worker = None  # break the loop
+
+
+class AsyncCopy(BaseCopy["AsyncConnection"]):
+    """Manage an asynchronous :sql:`COPY` operation."""
+
+    __module__ = "psycopg3"
+
+    def __init__(self, cursor: "AsyncCursor"):
+        super().__init__(cursor)
+        self._queue: asyncio.Queue[Optional[bytes]] = asyncio.Queue(
+            maxsize=self.QUEUE_SIZE
+        )
+        self._worker: Optional[asyncio.Future[None]] = None
+
+    async def __aenter__(self) -> "AsyncCopy":
         self._check_reuse()
         return self
 
-    def __exit__(
+    async def __aexit__(
         self,
         exc_type: Optional[Type[BaseException]],
         exc_val: Optional[BaseException],
         exc_tb: Optional[TracebackType],
     ) -> None:
-        self.connection.wait(self._exit_gen(exc_type, exc_val))
+        await self.finish(exc_val)
 
-    def __iter__(self) -> Iterator[memoryview]:
+    async def __aiter__(self) -> AsyncIterator[memoryview]:
         while True:
-            data = self.read()
+            data = await self.read()
             if not data:
                 break
             yield data
 
-
-class AsyncCopy(BaseCopy["AsyncConnection"]):
-    """Manage an asynchronous :sql:`COPY` operation."""
-
-    __module__ = "psycopg3"
-
     async def read(self) -> memoryview:
         return await self.connection.wait(self._read_gen())
 
@@ -285,32 +379,54 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
         return await self.connection.wait(self._read_row_gen())
 
     async def write(self, buffer: Union[str, bytes]) -> None:
-        await self.connection.wait(self._write_gen(buffer))
+        data = self._format_write(buffer)
+        await self._write(data)
 
     async def write_row(self, row: Sequence[Any]) -> None:
-        await self.connection.wait(self._write_row_gen(row))
+        data = self._format_write_row(row)
+        await self._write(data)
 
-    async def _finish(self, error: str = "") -> None:
-        await self.connection.wait(self._finish_gen(error))
+    async def finish(self, exc: Optional[BaseException]) -> None:
+        # no-op in COPY TO
+        if self._pgresult.status == ExecStatus.COPY_OUT:
+            return
 
-    async def __aenter__(self) -> "AsyncCopy":
-        self._check_reuse()
-        return self
+        await self._write_end()
+        await self.connection.wait(self._end_copy_gen(exc))
 
-    async def __aexit__(
-        self,
-        exc_type: Optional[Type[BaseException]],
-        exc_val: Optional[BaseException],
-        exc_tb: Optional[TracebackType],
-    ) -> None:
-        await self.connection.wait(self._exit_gen(exc_type, exc_val))
+    # Concurrent copy support
 
-    async def __aiter__(self) -> AsyncIterator[memoryview]:
-        while True:
-            data = await self.read()
+    async def worker(self) -> None:
+        """Push data to the server when available from the copy queue.
+
+        Terminate reading when the queue receives a None.
+
+        The function is designed to be run in a separate thread.
+        """
+        while 1:
+            data = await self._queue.get()
             if not data:
                 break
-            yield data
+            await self.connection.wait(copy_to(self._pgconn, data))
+
+    async def _write(self, data: bytes) -> None:
+        if not data:
+            return
+
+        if not self._worker:
+            # TODO: can be asyncio.create_task once Python 3.6 is dropped
+            self._worker = asyncio.ensure_future(self.worker())
+
+        await self._queue.put(data)
+
+    async def _write_end(self) -> None:
+        data = self._format_end()
+        await self._write(data)
+        await self._queue.put(None)
+
+        if self._worker:
+            await asyncio.gather(self._worker)
+            self._worker = None  # break reference loops if any
 
 
 def _format_row_text(
index 63e4242bceb3189c895526eea8a35cf6afa7a2ae..ea12f1e00faf0863c21cebf3b5e9387f0bceb17a 100644 (file)
@@ -479,6 +479,23 @@ def test_str(conn):
     assert "[INTRANS]" in str(copy)
 
 
+@pytest.mark.parametrize(
+    "format, buffer",
+    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+def test_worker_life(conn, format, buffer):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+    with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+        assert not copy._worker
+        copy.write(globals()[buffer])
+        assert copy._worker
+
+    assert not copy._worker
+    data = cur.execute("select * from copy_in order by 1").fetchall()
+    assert data == sample_records
+
+
 def py_to_raw(item, fmt):
     """Convert from Python type to the expected result from the db"""
     if fmt == Format.TEXT:
index 75855743c6cd8c8d3c1b654cd5cdbebcbe422413..6dced8dd9680f8d054ac4960caf5ab75f2b458d3 100644 (file)
@@ -452,6 +452,26 @@ async def test_str(aconn):
     assert "[INTRANS]" in str(copy)
 
 
+@pytest.mark.parametrize(
+    "format, buffer",
+    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+async def test_worker_life(aconn, format, buffer):
+    cur = await aconn.cursor()
+    await ensure_table(cur, sample_tabledef)
+    async with cur.copy(
+        f"copy copy_in from stdin (format {format.name})"
+    ) as copy:
+        assert not copy._worker
+        await copy.write(globals()[buffer])
+        assert copy._worker
+
+    assert not copy._worker
+    await cur.execute("select * from copy_in order by 1")
+    data = await cur.fetchall()
+    assert data == sample_records
+
+
 async def ensure_table(cur, tabledef, name="copy_in"):
     await cur.execute(f"drop table if exists {name}")
     await cur.execute(f"create table {name} ({tabledef})")