]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat(copy): add ConnectionWriter
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Jun 2022 14:58:02 +0000 (16:58 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 19 Jul 2022 14:09:53 +0000 (15:09 +0100)
This writer allows direct writing to the libpw, without buffering queue,
and provides the interface for a writer using a libpq connection to
write.

psycopg/psycopg/copy.py
tests/test_copy.py

index 80d121f080a5ebd6f3abb9b0e31ec8abfd948f26..42ec488c9089911e528cc8537bcf2cbca92646e8 100644 (file)
@@ -202,15 +202,17 @@ class Copy(BaseCopy["Connection[Any]"]):
     __module__ = "psycopg"
 
     def __init__(
-        self, cursor: "Cursor[Any]", writer: Optional["Writer[Connection[Any]]"] = None
+        self,
+        cursor: "Cursor[Any]",
+        *,
+        writer: Optional["Writer[Connection[Any]]"] = None,
     ):
         super().__init__(cursor)
         if not writer:
-            writer = QueueWriter(cursor)
+            writer = QueueWriter(cursor.connection)
 
         self.writer = writer
         self._write = writer.write
-        self._write_end = writer.write_end
 
     def __enter__(self: BaseCopy._Self) -> BaseCopy._Self:
         self._enter()
@@ -292,7 +294,9 @@ class Copy(BaseCopy["Connection[Any]"]):
         """
         if self._pgresult.status == COPY_IN:
             data = self.formatter.end()
-            self._write_end(data)
+            if data:
+                self._write(data)
+            self.writer.finish()
             self.connection.wait(self._end_copy_in_gen(exc))
         else:
             self.connection.wait(self._end_copy_out_gen(exc))
@@ -404,20 +408,38 @@ class Writer(Generic[ConnectionType], ABC):
 
     @abstractmethod
     def write(self, data: Buffer) -> None:
+        """
+        Write some data to destination.
+        """
         ...
 
-    @abstractmethod
-    def write_end(self, data: Buffer) -> None:
-        ...
+    def finish(self) -> None:
+        """
+        Called when write operations are finished.
+        """
+        pass
 
 
-class ConnectionWriter(Writer[ConnectionType]):
-    def __init__(self, cursor: "BaseCursor[ConnectionType, Any]"):
-        self.connection = cursor.connection
+class ConnectionWriter(Writer["Connection[Any]"]):
+    def __init__(self, connection: "Connection[Any]"):
+        self.connection = connection
         self._pgconn = self.connection.pgconn
 
+    def write(self, data: Buffer) -> None:
+        if len(data) <= MAX_BUFFER_SIZE:
+            # Most used path: we don't need to split the buffer in smaller
+            # bits, so don't make a copy.
+            self.connection.wait(copy_to(self._pgconn, data))
+        else:
+            # Copy a buffer too large in chunks to avoid causing a memory
+            # error in the libpq, which may cause an infinite loop (#255).
+            for i in range(0, len(data), MAX_BUFFER_SIZE):
+                self.connection.wait(
+                    copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
+                )
+
 
-class QueueWriter(ConnectionWriter["Connection[Any]"]):
+class QueueWriter(ConnectionWriter):
     """
     A writer using a buffer to queue data to write.
 
@@ -426,8 +448,8 @@ class QueueWriter(ConnectionWriter["Connection[Any]"]):
     on the connection.
     """
 
-    def __init__(self, cursor: "Cursor[Any]"):
-        super().__init__(cursor)
+    def __init__(self, connection: "Connection[Any]"):
+        super().__init__(connection)
 
         self._queue: queue.Queue[bytes] = queue.Queue(maxsize=QUEUE_SIZE)
         self._worker: Optional[threading.Thread] = None
@@ -472,9 +494,7 @@ class QueueWriter(ConnectionWriter["Connection[Any]"]):
             for i in range(0, len(data), MAX_BUFFER_SIZE):
                 self._queue.put(data[i : i + MAX_BUFFER_SIZE])
 
-    def write_end(self, data: Buffer) -> None:
-        if data:
-            self.write(data)
+    def finish(self) -> None:
         self._queue.put(b"")
 
         if self._worker:
index e23c1baf4ec09577b52977945f9f7c4fbb28d85f..25bc1a5dd2dc3b07ca3f5581e6b58fabb4fac7a9 100644 (file)
@@ -614,6 +614,25 @@ def test_worker_error_propagated(conn, monkeypatch):
             copy.write("a,b")
 
 
+@pytest.mark.parametrize(
+    "format, buffer",
+    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+def test_connection_writer(conn, format, buffer):
+    cur = conn.cursor()
+    writer = psycopg.copy.ConnectionWriter(conn)
+
+    ensure_table(cur, sample_tabledef)
+    with cur.copy(
+        f"copy copy_in from stdin (format {format.name})", writer=writer
+    ) as copy:
+        assert copy.writer is writer
+        copy.write(globals()[buffer])
+
+    data = cur.execute("select * from copy_in order by 1").fetchall()
+    assert data == sample_records
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize(
     "fmt, set_types",