__module__ = "psycopg"
def __init__(
- self, cursor: "Cursor[Any]", writer: Optional["Writer[Connection[Any]]"] = None
+ self,
+ cursor: "Cursor[Any]",
+ *,
+ writer: Optional["Writer[Connection[Any]]"] = None,
):
super().__init__(cursor)
if not writer:
- writer = QueueWriter(cursor)
+ writer = QueueWriter(cursor.connection)
self.writer = writer
self._write = writer.write
- self._write_end = writer.write_end
def __enter__(self: BaseCopy._Self) -> BaseCopy._Self:
self._enter()
"""
if self._pgresult.status == COPY_IN:
data = self.formatter.end()
- self._write_end(data)
+ if data:
+ self._write(data)
+ self.writer.finish()
self.connection.wait(self._end_copy_in_gen(exc))
else:
self.connection.wait(self._end_copy_out_gen(exc))
@abstractmethod
def write(self, data: Buffer) -> None:
+ """
+ Write some data to destination.
+ """
...
- @abstractmethod
- def write_end(self, data: Buffer) -> None:
- ...
+ def finish(self) -> None:
+ """
+ Called when write operations are finished.
+ """
+ pass
-class ConnectionWriter(Writer[ConnectionType]):
- def __init__(self, cursor: "BaseCursor[ConnectionType, Any]"):
- self.connection = cursor.connection
+class ConnectionWriter(Writer["Connection[Any]"]):
+ def __init__(self, connection: "Connection[Any]"):
+ self.connection = connection
self._pgconn = self.connection.pgconn
+ def write(self, data: Buffer) -> None:
+ if len(data) <= MAX_BUFFER_SIZE:
+ # Most used path: we don't need to split the buffer in smaller
+ # bits, so don't make a copy.
+ self.connection.wait(copy_to(self._pgconn, data))
+ else:
+ # Copy a buffer too large in chunks to avoid causing a memory
+ # error in the libpq, which may cause an infinite loop (#255).
+ for i in range(0, len(data), MAX_BUFFER_SIZE):
+ self.connection.wait(
+ copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
+ )
+
-class QueueWriter(ConnectionWriter["Connection[Any]"]):
+class QueueWriter(ConnectionWriter):
"""
A writer using a buffer to queue data to write.
on the connection.
"""
- def __init__(self, cursor: "Cursor[Any]"):
- super().__init__(cursor)
+ def __init__(self, connection: "Connection[Any]"):
+ super().__init__(connection)
self._queue: queue.Queue[bytes] = queue.Queue(maxsize=QUEUE_SIZE)
self._worker: Optional[threading.Thread] = None
for i in range(0, len(data), MAX_BUFFER_SIZE):
self._queue.put(data[i : i + MAX_BUFFER_SIZE])
- def write_end(self, data: Buffer) -> None:
- if data:
- self.write(data)
+ def finish(self) -> None:
self._queue.put(b"")
if self._worker:
copy.write("a,b")
+@pytest.mark.parametrize(
+ "format, buffer",
+ [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+def test_connection_writer(conn, format, buffer):
+ cur = conn.cursor()
+ writer = psycopg.copy.ConnectionWriter(conn)
+
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(
+ f"copy copy_in from stdin (format {format.name})", writer=writer
+ ) as copy:
+ assert copy.writer is writer
+ copy.write(globals()[buffer])
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
@pytest.mark.slow
@pytest.mark.parametrize(
"fmt, set_types",