git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat(copy): add writer param to AsyncCursor.copy()
author Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Jun 2022 15:19:46 +0000 (17:19 +0200)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 19 Jul 2022 14:09:53 +0000 (15:09 +0100)
Add async copy writers similar to the sync ones.

psycopg/psycopg/copy.py
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
tests/test_copy_async.py

index 42ec488c9089911e528cc8537bcf2cbca92646e8..b9d641921bcc7a7c139d0613e3ecf6d41dadb133 100644 (file)
@@ -77,7 +77,6 @@ class BaseCopy(Generic[ConnectionType]):
     _Self = TypeVar("_Self", bound="BaseCopy[ConnectionType]")
 
     formatter: "Formatter"
-    writer: "Writer[ConnectionType]"
 
     def __init__(self, cursor: "BaseCursor[ConnectionType, Any]"):
         self.cursor = cursor
@@ -201,12 +200,9 @@ class Copy(BaseCopy["Connection[Any]"]):
 
     __module__ = "psycopg"
 
-    def __init__(
-        self,
-        cursor: "Cursor[Any]",
-        *,
-        writer: Optional["Writer[Connection[Any]]"] = None,
-    ):
+    writer: "Writer"
+
+    def __init__(self, cursor: "Cursor[Any]", *, writer: Optional["Writer"] = None):
         super().__init__(cursor)
         if not writer:
             writer = QueueWriter(cursor.connection)
@@ -302,15 +298,128 @@ class Copy(BaseCopy["Connection[Any]"]):
             self.connection.wait(self._end_copy_out_gen(exc))
 
 
+class Writer(ABC):
+    """
+    A class to write copy data somewhere.
+    """
+
+    @abstractmethod
+    def write(self, data: Buffer) -> None:
+        """
+        Write some data to destination.
+        """
+        ...
+
+    def finish(self) -> None:
+        """
+        Called when write operations are finished.
+        """
+        pass
+
+
+class ConnectionWriter(Writer):
+    def __init__(self, connection: "Connection[Any]"):
+        self.connection = connection
+        self._pgconn = self.connection.pgconn
+
+    def write(self, data: Buffer) -> None:
+        if len(data) <= MAX_BUFFER_SIZE:
+            # Most used path: we don't need to split the buffer in smaller
+            # bits, so don't make a copy.
+            self.connection.wait(copy_to(self._pgconn, data))
+        else:
+            # Copy a buffer too large in chunks to avoid causing a memory
+            # error in the libpq, which may cause an infinite loop (#255).
+            for i in range(0, len(data), MAX_BUFFER_SIZE):
+                self.connection.wait(
+                    copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
+                )
+
+
+class QueueWriter(ConnectionWriter):
+    """
+    A writer using a buffer to queue data to write.
+
+    `write()` returns immediately, so that the main thread can be CPU-bound
+    formatting messages, while a worker thread can be IO-bound waiting to write
+    on the connection.
+    """
+
+    def __init__(self, connection: "Connection[Any]"):
+        super().__init__(connection)
+
+        self._queue: queue.Queue[bytes] = queue.Queue(maxsize=QUEUE_SIZE)
+        self._worker: Optional[threading.Thread] = None
+        self._worker_error: Optional[BaseException] = None
+
+    def worker(self) -> None:
+        """Push data to the server when available from the copy queue.
+
+        Terminate reading when the queue receives a false-y value, or in case
+        of error.
+
+        The function is designed to be run in a separate thread.
+        """
+        try:
+            while True:
+                data = self._queue.get(block=True, timeout=24 * 60 * 60)
+                if not data:
+                    break
+                self.connection.wait(copy_to(self._pgconn, data))
+        except BaseException as ex:
+            # Propagate the error to the main thread.
+            self._worker_error = ex
+
+    def write(self, data: Buffer) -> None:
+        if not self._worker:
+            # warning: reference loop, broken by _write_end
+            self._worker = threading.Thread(target=self.worker)
+            self._worker.daemon = True
+            self._worker.start()
+
+        # If the worker thread raises an exception, re-raise it to the caller.
+        if self._worker_error:
+            raise self._worker_error
+
+        if len(data) <= MAX_BUFFER_SIZE:
+            # Most used path: we don't need to split the buffer in smaller
+            # bits, so don't make a copy.
+            self._queue.put(data)
+        else:
+            # Copy a buffer too large in chunks to avoid causing a memory
+            # error in the libpq, which may cause an infinite loop (#255).
+            for i in range(0, len(data), MAX_BUFFER_SIZE):
+                self._queue.put(data[i : i + MAX_BUFFER_SIZE])
+
+    def finish(self) -> None:
+        self._queue.put(b"")
+
+        if self._worker:
+            self._worker.join()
+            self._worker = None  # break the loop
+
+        # Check if the worker thread raised any exception before terminating.
+        if self._worker_error:
+            raise self._worker_error
+
+
 class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
     """Manage an asynchronous :sql:`COPY` operation."""
 
     __module__ = "psycopg"
 
-    def __init__(self, cursor: "AsyncCursor[Any]"):
+    writer: "AsyncWriter"
+
+    def __init__(
+        self, cursor: "AsyncCursor[Any]", *, writer: Optional["AsyncWriter"] = None
+    ):
         super().__init__(cursor)
-        self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=QUEUE_SIZE)
-        self._worker: Optional[asyncio.Future[None]] = None
+
+        if not writer:
+            writer = AsyncQueueWriter(cursor.connection)
+
+        self.writer = writer
+        self._write = writer.write
 
     async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self:
         self._enter()
@@ -356,90 +465,54 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
 
     async def finish(self, exc: Optional[BaseException]) -> None:
         if self._pgresult.status == COPY_IN:
-            await self._write_end()
+            data = self.formatter.end()
+            if data:
+                await self._write(data)
+            await self.writer.finish()
             await self.connection.wait(self._end_copy_in_gen(exc))
         else:
             await self.connection.wait(self._end_copy_out_gen(exc))
 
-    # Concurrent copy support
 
-    async def worker(self) -> None:
-        """Push data to the server when available from the copy queue.
-
-        Terminate reading when the queue receives a false-y value.
-
-        The function is designed to be run in a separate thread.
-        """
-        while True:
-            data = await self._queue.get()
-            if not data:
-                break
-            await self.connection.wait(copy_to(self._pgconn, data))
-
-    async def _write(self, data: Buffer) -> None:
-        if not self._worker:
-            self._worker = create_task(self.worker())
-
-        if len(data) <= MAX_BUFFER_SIZE:
-            # Most used path: we don't need to split the buffer in smaller
-            # bits, so don't make a copy.
-            await self._queue.put(data)
-        else:
-            # Copy a buffer too large in chunks to avoid causing a memory
-            # error in the libpq, which may cause an infinite loop (#255).
-            for i in range(0, len(data), MAX_BUFFER_SIZE):
-                await self._queue.put(data[i : i + MAX_BUFFER_SIZE])
-
-    async def _write_end(self) -> None:
-        data = self.formatter.end()
-        if data:
-            await self._write(data)
-        await self._queue.put(b"")
-
-        if self._worker:
-            await asyncio.gather(self._worker)
-            self._worker = None  # break reference loops if any
-
-
-class Writer(Generic[ConnectionType], ABC):
+class AsyncWriter(ABC):
     """
     A class to write copy data somewhere.
     """
 
     @abstractmethod
-    def write(self, data: Buffer) -> None:
+    async def write(self, data: Buffer) -> None:
         """
         Write some data to destination.
         """
         ...
 
-    def finish(self) -> None:
+    async def finish(self) -> None:
         """
         Called when write operations are finished.
         """
         pass
 
 
-class ConnectionWriter(Writer["Connection[Any]"]):
-    def __init__(self, connection: "Connection[Any]"):
+class AsyncConnectionWriter(AsyncWriter):
+    def __init__(self, connection: "AsyncConnection[Any]"):
         self.connection = connection
         self._pgconn = self.connection.pgconn
 
-    def write(self, data: Buffer) -> None:
+    async def write(self, data: Buffer) -> None:
         if len(data) <= MAX_BUFFER_SIZE:
             # Most used path: we don't need to split the buffer in smaller
             # bits, so don't make a copy.
-            self.connection.wait(copy_to(self._pgconn, data))
+            await self.connection.wait(copy_to(self._pgconn, data))
         else:
             # Copy a buffer too large in chunks to avoid causing a memory
             # error in the libpq, which may cause an infinite loop (#255).
             for i in range(0, len(data), MAX_BUFFER_SIZE):
-                self.connection.wait(
+                await self.connection.wait(
                     copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
                 )
 
 
-class QueueWriter(ConnectionWriter):
+class AsyncQueueWriter(AsyncConnectionWriter):
     """
     A writer using a buffer to queue data to write.
 
@@ -448,62 +521,45 @@ class QueueWriter(ConnectionWriter):
     on the connection.
     """
 
-    def __init__(self, connection: "Connection[Any]"):
+    def __init__(self, connection: "AsyncConnection[Any]"):
         super().__init__(connection)
 
-        self._queue: queue.Queue[bytes] = queue.Queue(maxsize=QUEUE_SIZE)
-        self._worker: Optional[threading.Thread] = None
-        self._worker_error: Optional[BaseException] = None
+        self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=QUEUE_SIZE)
+        self._worker: Optional[asyncio.Future[None]] = None
 
-    def worker(self) -> None:
+    async def worker(self) -> None:
         """Push data to the server when available from the copy queue.
 
-        Terminate reading when the queue receives a false-y value, or in case
-        of error.
+        Terminate reading when the queue receives a false-y value.
 
-        The function is designed to be run in a separate thread.
+        The function is designed to be run in a separate task.
         """
-        try:
-            while True:
-                data = self._queue.get(block=True, timeout=24 * 60 * 60)
-                if not data:
-                    break
-                self.connection.wait(copy_to(self._pgconn, data))
-        except BaseException as ex:
-            # Propagate the error to the main thread.
-            self._worker_error = ex
+        while True:
+            data = await self._queue.get()
+            if not data:
+                break
+            await self.connection.wait(copy_to(self._pgconn, data))
 
-    def write(self, data: Buffer) -> None:
+    async def write(self, data: Buffer) -> None:
         if not self._worker:
-            # warning: reference loop, broken by _write_end
-            self._worker = threading.Thread(target=self.worker)
-            self._worker.daemon = True
-            self._worker.start()
-
-        # If the worker thread raies an exception, re-raise it to the caller.
-        if self._worker_error:
-            raise self._worker_error
+            self._worker = create_task(self.worker())
 
         if len(data) <= MAX_BUFFER_SIZE:
             # Most used path: we don't need to split the buffer in smaller
             # bits, so don't make a copy.
-            self._queue.put(data)
+            await self._queue.put(data)
         else:
             # Copy a buffer too large in chunks to avoid causing a memory
             # error in the libpq, which may cause an infinite loop (#255).
             for i in range(0, len(data), MAX_BUFFER_SIZE):
-                self._queue.put(data[i : i + MAX_BUFFER_SIZE])
+                await self._queue.put(data[i : i + MAX_BUFFER_SIZE])
 
-    def finish(self) -> None:
-        self._queue.put(b"")
+    async def finish(self) -> None:
+        await self._queue.put(b"")
 
         if self._worker:
-            self._worker.join()
-            self._worker = None  # break the loop
-
-        # Check if the worker thread raised any exception before terminating.
-        if self._worker_error:
-            raise self._worker_error
+            await asyncio.gather(self._worker)
+            self._worker = None  # break reference loops if any
 
 
 class Formatter(ABC):
index bab46c2d9d09f435dbdfcdb13d48863dcaaf7915..77ad80788eb06b25ba35324f2f70aa46113888b4 100644 (file)
@@ -878,7 +878,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         statement: Query,
         params: Optional[Params] = None,
         *,
-        writer: Optional[CopyWriter[Any]] = None,
+        writer: Optional[CopyWriter] = None,
     ) -> Iterator[Copy]:
         """
         Initiate a :sql:`COPY` operation and return an object to manage it.
index 7f71efaad887f7212364edbc79bf0e59aaa8e265..4b4844d40f4e0a64f6dbf9060d249a7b98e18154 100644 (file)
@@ -12,7 +12,7 @@ from contextlib import asynccontextmanager
 from . import pq
 from . import errors as e
 from .abc import Query, Params
-from .copy import AsyncCopy
+from .copy import AsyncCopy, AsyncWriter as AsyncCopyWriter
 from .rows import Row, RowMaker, AsyncRowFactory
 from .cursor import BaseCursor
 from ._pipeline import Pipeline
@@ -204,7 +204,11 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
 
     @asynccontextmanager
     async def copy(
-        self, statement: Query, params: Optional[Params] = None
+        self,
+        statement: Query,
+        params: Optional[Params] = None,
+        *,
+        writer: Optional[AsyncCopyWriter] = None,
     ) -> AsyncIterator[AsyncCopy]:
         """
         :rtype: AsyncCopy
@@ -213,7 +217,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
             async with self._conn.lock:
                 await self._conn.wait(self._start_copy_gen(statement, params))
 
-            async with AsyncCopy(self) as copy:
+            async with AsyncCopy(self, writer=writer) as copy:
                 yield copy
         except e.Error as ex:
             raise ex.with_traceback(None)
index 3b0008dde1d6a7c41ae0ee1d3c93fa251a408623..045ad85144f16f2a18cf1a8cc7d5024ce2947dfd 100644 (file)
@@ -596,11 +596,11 @@ async def test_worker_life(aconn, format, buffer):
     cur = aconn.cursor()
     await ensure_table(cur, sample_tabledef)
     async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
-        assert not copy._worker
+        assert not copy.writer._worker
         await copy.write(globals()[buffer])
-        assert copy._worker
+        assert copy.writer._worker
 
-    assert not copy._worker
+    assert not copy.writer._worker
     await cur.execute("select * from copy_in order by 1")
     data = await cur.fetchall()
     assert data == sample_records
@@ -619,6 +619,26 @@ async def test_worker_error_propagated(aconn, monkeypatch):
             await copy.write("a,b")
 
 
+@pytest.mark.parametrize(
+    "format, buffer",
+    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+async def test_connection_writer(aconn, format, buffer):
+    cur = aconn.cursor()
+    writer = psycopg.copy.AsyncConnectionWriter(aconn)
+
+    await ensure_table(cur, sample_tabledef)
+    async with cur.copy(
+        f"copy copy_in from stdin (format {format.name})", writer=writer
+    ) as copy:
+        assert copy.writer is writer
+        await copy.write(globals()[buffer])
+
+    await cur.execute("select * from copy_in order by 1")
+    data = await cur.fetchall()
+    assert data == sample_records
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize(
     "fmt, set_types",