refactor(copy): make the writer entirely responsible for the libpq copy state
author    Daniele Varrazzo <daniele.varrazzo@gmail.com>
          Sat, 11 Jun 2022 22:47:50 +0000 (00:47 +0200)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
          Tue, 19 Jul 2022 14:09:53 +0000 (15:09 +0100)
This way we can create a writer that is entirely independent of libpq and
of the connection, which is useful, for instance, to write COPY-formatted
data to a file.
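
As a minimal sketch of what this enables, modelled on the BytesWriter helper
added to the tests below (it assumes an open psycopg connection `conn`; the
Copy object still needs a cursor, but no data reaches the libpq COPY state):

    from io import BytesIO

    import psycopg
    import psycopg.copy

    class BytesWriter(psycopg.copy.Writer):
        """Collect the formatted COPY data in memory instead of sending it."""

        def __init__(self):
            self.file = BytesIO()

        def write(self, data):
            self.file.write(data)

    # The Copy object formats the rows; the custom writer only stores them,
    # so nothing is sent on the connection.
    writer = BytesWriter()
    with psycopg.copy.Copy(conn.cursor(), writer=writer) as copy:
        copy.write_row((1, "hello"))

    print(writer.file.getvalue())  # rows in the text COPY format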

psycopg/psycopg/copy.py
tests/test_copy.py
tests/test_copy_async.py

psycopg/psycopg/copy.py
index b9d641921bcc7a7c139d0613e3ecf6d41dadb133..6ac502500c6bdb7f54dc3444bf9230792b80575d 100644
@@ -15,28 +15,28 @@ from typing import Any, AsyncIterator, Dict, Generic, Iterator, List, Match
 from typing import Optional, Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING
 
 from . import pq
+from . import adapt
 from . import errors as e
 from .abc import Buffer, ConnectionType, PQGen, Transformer
-from .adapt import PyFormat
 from ._compat import create_task
 from ._cmodule import _psycopg
 from ._encodings import pgconn_encoding
 from .generators import copy_from, copy_to, copy_end
 
 if TYPE_CHECKING:
-    from .pq.abc import PGresult
     from .cursor import BaseCursor, Cursor
     from .cursor_async import AsyncCursor
     from .connection import Connection  # noqa: F401
     from .connection_async import AsyncConnection  # noqa: F401
 
-PY_TEXT = PyFormat.TEXT
-PY_BINARY = PyFormat.BINARY
+PY_TEXT = adapt.PyFormat.TEXT
+PY_BINARY = adapt.PyFormat.BINARY
 
 TEXT = pq.Format.TEXT
 BINARY = pq.Format.BINARY
 
 COPY_IN = pq.ExecStatus.COPY_IN
+COPY_OUT = pq.ExecStatus.COPY_OUT
 
 ACTIVE = pq.TransactionStatus.ACTIVE
 
@@ -83,14 +83,22 @@ class BaseCopy(Generic[ConnectionType]):
         self.connection = cursor.connection
         self._pgconn = self.connection.pgconn
 
-        tx = cursor._tx
-        assert tx.pgresult, "The Transformer doesn't have a PGresult set"
-        self._pgresult: "PGresult" = tx.pgresult
-
-        if self._pgresult.binary_tuples == TEXT:
-            self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn))
+        result = cursor.pgresult
+        if result:
+            self._direction = result.status
+            if self._direction != COPY_IN and self._direction != COPY_OUT:
+                raise e.ProgrammingError(
+                    "the cursor should have performed a COPY operation;"
+                    f" its status is {pq.ExecStatus(self._direction).name} instead"
+                )
         else:
+            self._direction = COPY_IN
+
+        tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor)
+        if result and result.binary_tuples:
             self.formatter = BinaryFormatter(tx)
+        else:
+            self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn))
 
         self._finished = False
 
@@ -125,7 +133,7 @@ class BaseCopy(Generic[ConnectionType]):
         registry = self.cursor.adapters.types
         oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types]
 
-        if self._pgresult.status == COPY_IN:
+        if self._direction == COPY_IN:
             self.formatter.transformer.set_dumper_types(oids, self.formatter.format)
         else:
             self.formatter.transformer.set_loader_types(oids, self.formatter.format)
@@ -160,20 +168,6 @@ class BaseCopy(Generic[ConnectionType]):
 
         return row
 
-    def _end_copy_in_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
-        bmsg: Optional[bytes]
-        if exc:
-            msg = f"error from Python: {type(exc).__qualname__} - {exc}"
-            bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace")
-        else:
-            bmsg = None
-
-        res = yield from copy_end(self._pgconn, bmsg)
-
-        nrows = res.command_tuples
-        self.cursor._rowcount = nrows if nrows is not None else -1
-        self._finished = True
-
     def _end_copy_out_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
         if not exc:
             return
@@ -205,7 +199,7 @@ class Copy(BaseCopy["Connection[Any]"]):
     def __init__(self, cursor: "Cursor[Any]", *, writer: Optional["Writer"] = None):
         super().__init__(cursor)
         if not writer:
-            writer = QueueWriter(cursor.connection)
+            writer = QueueWriter(cursor)
 
         self.writer = writer
         self._write = writer.write
@@ -288,12 +282,12 @@ class Copy(BaseCopy["Connection[Any]"]):
         by exit. It is available if, despite what is documented, you end up
         using the `Copy` object outside a block.
         """
-        if self._pgresult.status == COPY_IN:
+        if self._direction == COPY_IN:
             data = self.formatter.end()
             if data:
                 self._write(data)
-            self.writer.finish()
-            self.connection.wait(self._end_copy_in_gen(exc))
+            self.writer.finish(exc)
+            self._finished = True
         else:
             self.connection.wait(self._end_copy_out_gen(exc))
 
@@ -310,7 +304,7 @@ class Writer(ABC):
         """
         ...
 
-    def finish(self) -> None:
+    def finish(self, exc: Optional[BaseException] = None) -> None:
         """
         Called when write operations are finished.
         """
@@ -318,8 +312,9 @@ class Writer(ABC):
 
 
 class ConnectionWriter(Writer):
-    def __init__(self, connection: "Connection[Any]"):
-        self.connection = connection
+    def __init__(self, cursor: "Cursor[Any]"):
+        self.cursor = cursor
+        self.connection = cursor.connection
         self._pgconn = self.connection.pgconn
 
     def write(self, data: Buffer) -> None:
@@ -335,6 +330,18 @@ class ConnectionWriter(Writer):
                     copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
                 )
 
+    def finish(self, exc: Optional[BaseException] = None) -> None:
+        bmsg: Optional[bytes]
+        if exc:
+            msg = f"error from Python: {type(exc).__qualname__} - {exc}"
+            bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace")
+        else:
+            bmsg = None
+
+        res = self.connection.wait(copy_end(self._pgconn, bmsg))
+        nrows = res.command_tuples
+        self.cursor._rowcount = nrows if nrows is not None else -1
+
 
 class QueueWriter(ConnectionWriter):
     """
@@ -345,8 +352,8 @@ class QueueWriter(ConnectionWriter):
     on the connection.
     """
 
-    def __init__(self, connection: "Connection[Any]"):
-        super().__init__(connection)
+    def __init__(self, cursor: "Cursor[Any]"):
+        super().__init__(cursor)
 
         self._queue: queue.Queue[bytes] = queue.Queue(maxsize=QUEUE_SIZE)
         self._worker: Optional[threading.Thread] = None
@@ -391,7 +398,7 @@ class QueueWriter(ConnectionWriter):
             for i in range(0, len(data), MAX_BUFFER_SIZE):
                 self._queue.put(data[i : i + MAX_BUFFER_SIZE])
 
-    def finish(self) -> None:
+    def finish(self, exc: Optional[BaseException] = None) -> None:
         self._queue.put(b"")
 
         if self._worker:
@@ -402,6 +409,8 @@ class QueueWriter(ConnectionWriter):
         if self._worker_error:
             raise self._worker_error
 
+        super().finish(exc)
+
 
 class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
     """Manage an asynchronous :sql:`COPY` operation."""
@@ -416,7 +425,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
         super().__init__(cursor)
 
         if not writer:
-            writer = AsyncQueueWriter(cursor.connection)
+            writer = AsyncQueueWriter(cursor)
 
         self.writer = writer
         self._write = writer.write
@@ -464,12 +473,12 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
             await self._write(data)
 
     async def finish(self, exc: Optional[BaseException]) -> None:
-        if self._pgresult.status == COPY_IN:
+        if self._direction == COPY_IN:
             data = self.formatter.end()
             if data:
                 await self._write(data)
-            await self.writer.finish()
-            await self.connection.wait(self._end_copy_in_gen(exc))
+            await self.writer.finish(exc)
+            self._finished = True
         else:
             await self.connection.wait(self._end_copy_out_gen(exc))
 
@@ -486,7 +495,7 @@ class AsyncWriter(ABC):
         """
         ...
 
-    async def finish(self) -> None:
+    async def finish(self, exc: Optional[BaseException] = None) -> None:
         """
         Called when write operations are finished.
         """
@@ -494,8 +503,9 @@ class AsyncWriter(ABC):
 
 
 class AsyncConnectionWriter(AsyncWriter):
-    def __init__(self, connection: "AsyncConnection[Any]"):
-        self.connection = connection
+    def __init__(self, cursor: "AsyncCursor[Any]"):
+        self.cursor = cursor
+        self.connection = cursor.connection
         self._pgconn = self.connection.pgconn
 
     async def write(self, data: Buffer) -> None:
@@ -511,6 +521,18 @@ class AsyncConnectionWriter(AsyncWriter):
                     copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
                 )
 
+    async def finish(self, exc: Optional[BaseException] = None) -> None:
+        bmsg: Optional[bytes]
+        if exc:
+            msg = f"error from Python: {type(exc).__qualname__} - {exc}"
+            bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace")
+        else:
+            bmsg = None
+
+        res = await self.connection.wait(copy_end(self._pgconn, bmsg))
+        nrows = res.command_tuples
+        self.cursor._rowcount = nrows if nrows is not None else -1
+
 
 class AsyncQueueWriter(AsyncConnectionWriter):
     """
@@ -521,8 +543,8 @@ class AsyncQueueWriter(AsyncConnectionWriter):
     on the connection.
     """
 
-    def __init__(self, connection: "AsyncConnection[Any]"):
-        super().__init__(connection)
+    def __init__(self, cursor: "AsyncCursor[Any]"):
+        super().__init__(cursor)
 
         self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=QUEUE_SIZE)
         self._worker: Optional[asyncio.Future[None]] = None
@@ -554,13 +576,15 @@ class AsyncQueueWriter(AsyncConnectionWriter):
             for i in range(0, len(data), MAX_BUFFER_SIZE):
                 await self._queue.put(data[i : i + MAX_BUFFER_SIZE])
 
-    async def finish(self) -> None:
+    async def finish(self, exc: Optional[BaseException] = None) -> None:
         await self._queue.put(b"")
 
         if self._worker:
             await asyncio.gather(self._worker)
             self._worker = None  # break reference loops if any
 
+        await super().finish(exc)
+
 
 class Formatter(ABC):
     """
tests/test_copy.py
index 25bc1a5dd2dc3b07ca3f5581e6b58fabb4fac7a9..78a7ebec71941b05628272fd35c6e7e35888e240 100644
@@ -46,6 +46,8 @@ sample_binary_rows = [
 
 sample_binary = b"".join(sample_binary_rows)
 
+special_chars = {8: "b", 9: "t", 10: "n", 11: "v", 12: "f", 13: "r", ord("\\"): "\\"}
+
 
 @pytest.mark.parametrize("format", Format)
 def test_copy_out_read(conn, format):
@@ -458,6 +460,29 @@ from copy_in group by 1, 2, 3
     assert data == [(True, True, 1, 256)]
 
 
+def test_copy_in_format(conn):
+    writer = BytesWriter()
+    conn.execute("set client_encoding to utf8")
+    cur = conn.cursor()
+    with psycopg.copy.Copy(cur, writer=writer) as copy:
+        for i in range(1, 256):
+            copy.write_row((i, chr(i)))
+
+    writer.file.seek(0)
+    rows = writer.file.read().split(b"\n")
+    assert not rows[-1]
+    del rows[-1]
+
+    for i, row in enumerate(rows, start=1):
+        fields = row.split(b"\t")
+        assert len(fields) == 2
+        assert int(fields[0].decode()) == i
+        if i in special_chars:
+            assert fields[1].decode() == f"\\{special_chars[i]}"
+        else:
+            assert fields[1].decode() == chr(i)
+
+
 @pytest.mark.slow
 def test_copy_from_to(conn):
     # Roundtrip from file to database to file blockwise
@@ -620,7 +645,7 @@ def test_worker_error_propagated(conn, monkeypatch):
 )
 def test_connection_writer(conn, format, buffer):
     cur = conn.cursor()
-    writer = psycopg.copy.ConnectionWriter(conn)
+    writer = psycopg.copy.ConnectionWriter(cur)
 
     ensure_table(cur, sample_tabledef)
     with cur.copy(
@@ -832,3 +857,11 @@ class DataGenerator:
                 block = block.encode()
             m.update(block)
         return m.hexdigest()
+
+
+class BytesWriter(psycopg.copy.Writer):
+    def __init__(self):
+        self.file = BytesIO()
+
+    def write(self, data):
+        self.file.write(data)
tests/test_copy_async.py
index 045ad85144f16f2a18cf1a8cc7d5024ce2947dfd..bbaf77aaa5b3a7459933f98f75062552492595cd 100644
@@ -20,7 +20,7 @@ from psycopg.types.numeric import Int4
 from .utils import alist, eur, gc_collect
 from .test_copy import sample_text, sample_binary, sample_binary_rows  # noqa
 from .test_copy import sample_values, sample_records, sample_tabledef
-from .test_copy import py_to_raw
+from .test_copy import py_to_raw, special_chars
 
 pytestmark = [
     pytest.mark.asyncio,
@@ -462,6 +462,29 @@ from copy_in group by 1, 2, 3
     assert data == [(True, True, 1, 256)]
 
 
+async def test_copy_in_format(aconn):
+    writer = AsyncBytesWriter()
+    await aconn.execute("set client_encoding to utf8")
+    cur = aconn.cursor()
+    async with psycopg.copy.AsyncCopy(cur, writer=writer) as copy:
+        for i in range(1, 256):
+            await copy.write_row((i, chr(i)))
+
+    writer.file.seek(0)
+    rows = writer.file.read().split(b"\n")
+    assert not rows[-1]
+    del rows[-1]
+
+    for i, row in enumerate(rows, start=1):
+        fields = row.split(b"\t")
+        assert len(fields) == 2
+        assert int(fields[0].decode()) == i
+        if i in special_chars:
+            assert fields[1].decode() == f"\\{special_chars[i]}"
+        else:
+            assert fields[1].decode() == chr(i)
+
+
 @pytest.mark.slow
 async def test_copy_from_to(aconn):
     # Roundtrip from file to database to file blockwise
@@ -625,7 +648,7 @@ async def test_worker_error_propagated(aconn, monkeypatch):
 )
 async def test_connection_writer(aconn, format, buffer):
     cur = aconn.cursor()
-    writer = psycopg.copy.AsyncConnectionWriter(aconn)
+    writer = psycopg.copy.AsyncConnectionWriter(cur)
 
     await ensure_table(cur, sample_tabledef)
     async with cur.copy(
@@ -827,3 +850,11 @@ class DataGenerator:
                 block = block.encode()
             m.update(block)
         return m.hexdigest()
+
+
+class AsyncBytesWriter(psycopg.copy.AsyncWriter):
+    def __init__(self):
+        self.file = BytesIO()
+
+    async def write(self, data):
+        self.file.write(data)