From: Daniele Varrazzo Date: Sat, 11 Jun 2022 14:58:02 +0000 (+0200) Subject: feat(copy): add ConnectionWriter X-Git-Tag: 3.1~44^2~11 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f450693921628bece1db461fd318bf70202b3b41;p=thirdparty%2Fpsycopg.git feat(copy): add ConnectionWriter This writer allows direct writing to the libpw, without buffering queue, and provides the interface for a writer using a libpq connection to write. --- diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py index 80d121f08..42ec488c9 100644 --- a/psycopg/psycopg/copy.py +++ b/psycopg/psycopg/copy.py @@ -202,15 +202,17 @@ class Copy(BaseCopy["Connection[Any]"]): __module__ = "psycopg" def __init__( - self, cursor: "Cursor[Any]", writer: Optional["Writer[Connection[Any]]"] = None + self, + cursor: "Cursor[Any]", + *, + writer: Optional["Writer[Connection[Any]]"] = None, ): super().__init__(cursor) if not writer: - writer = QueueWriter(cursor) + writer = QueueWriter(cursor.connection) self.writer = writer self._write = writer.write - self._write_end = writer.write_end def __enter__(self: BaseCopy._Self) -> BaseCopy._Self: self._enter() @@ -292,7 +294,9 @@ class Copy(BaseCopy["Connection[Any]"]): """ if self._pgresult.status == COPY_IN: data = self.formatter.end() - self._write_end(data) + if data: + self._write(data) + self.writer.finish() self.connection.wait(self._end_copy_in_gen(exc)) else: self.connection.wait(self._end_copy_out_gen(exc)) @@ -404,20 +408,38 @@ class Writer(Generic[ConnectionType], ABC): @abstractmethod def write(self, data: Buffer) -> None: + """ + Write some data to destination. + """ ... - @abstractmethod - def write_end(self, data: Buffer) -> None: - ... + def finish(self) -> None: + """ + Called when write operations are finished. + """ + pass -class ConnectionWriter(Writer[ConnectionType]): - def __init__(self, cursor: "BaseCursor[ConnectionType, Any]"): - self.connection = cursor.connection +class ConnectionWriter(Writer["Connection[Any]"]): + def __init__(self, connection: "Connection[Any]"): + self.connection = connection self._pgconn = self.connection.pgconn + def write(self, data: Buffer) -> None: + if len(data) <= MAX_BUFFER_SIZE: + # Most used path: we don't need to split the buffer in smaller + # bits, so don't make a copy. + self.connection.wait(copy_to(self._pgconn, data)) + else: + # Copy a buffer too large in chunks to avoid causing a memory + # error in the libpq, which may cause an infinite loop (#255). + for i in range(0, len(data), MAX_BUFFER_SIZE): + self.connection.wait( + copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE]) + ) + -class QueueWriter(ConnectionWriter["Connection[Any]"]): +class QueueWriter(ConnectionWriter): """ A writer using a buffer to queue data to write. @@ -426,8 +448,8 @@ class QueueWriter(ConnectionWriter["Connection[Any]"]): on the connection. """ - def __init__(self, cursor: "Cursor[Any]"): - super().__init__(cursor) + def __init__(self, connection: "Connection[Any]"): + super().__init__(connection) self._queue: queue.Queue[bytes] = queue.Queue(maxsize=QUEUE_SIZE) self._worker: Optional[threading.Thread] = None @@ -472,9 +494,7 @@ class QueueWriter(ConnectionWriter["Connection[Any]"]): for i in range(0, len(data), MAX_BUFFER_SIZE): self._queue.put(data[i : i + MAX_BUFFER_SIZE]) - def write_end(self, data: Buffer) -> None: - if data: - self.write(data) + def finish(self) -> None: self._queue.put(b"") if self._worker: diff --git a/tests/test_copy.py b/tests/test_copy.py index e23c1baf4..25bc1a5dd 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -614,6 +614,25 @@ def test_worker_error_propagated(conn, monkeypatch): copy.write("a,b") +@pytest.mark.parametrize( + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], +) +def test_connection_writer(conn, format, buffer): + cur = conn.cursor() + writer = psycopg.copy.ConnectionWriter(conn) + + ensure_table(cur, sample_tabledef) + with cur.copy( + f"copy copy_in from stdin (format {format.name})", writer=writer + ) as copy: + assert copy.writer is writer + copy.write(globals()[buffer]) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + @pytest.mark.slow @pytest.mark.parametrize( "fmt, set_types",