# Copyright (C) 2020 The Psycopg Team
import re
+import queue
import struct
+import asyncio
+import threading
from types import TracebackType
from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union
from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple
if TYPE_CHECKING:
from .pq.proto import PGresult
- from .cursor import BaseCursor # noqa: F401
+ from .cursor import BaseCursor, Cursor, AsyncCursor
from .connection import Connection, AsyncConnection # noqa: F401
TEXT = pq.Format.TEXT
class BaseCopy(Generic[ConnectionType]):
+ """
+ Base implementation for copy user interface
+
+ Two subclasses expose real methods with the sync/async differences.
+
+ The difference between the text and binary format is managed by two
+ different `Formatter` subclasses.
+
+ While the interface doesn't dictate it, both subclasses are implemented
+ with a worker to perform I/O related work, consuming the data provided in
+ the correct format from a queue, while the main thread is concerned with
+ formatting the data in copy format and adding it to the queue.
+ """
+
+ # Size (in bytes) of formatted rows to accumulate before sending them down
+ # the network
+ BUFFER_SIZE = 32 * 1024
+
+ # Max number of buffers waiting in the write queue. When the queue is full
+ # writing blocks until the worker has consumed some of them.
+ QUEUE_SIZE = 1024
+
def __init__(self, cursor: "BaseCursor[ConnectionType]"):
self.cursor = cursor
self.connection = cursor.connection
self._signature_sent = False
self._row_mode = False # true if the user is using send_row()
self._write_buffer = bytearray()
- self._write_buffer_size = 32 * 1024
self._finished = False
if self.format == TEXT:
return self._parse_row(data, self.transformer)
- def _write_gen(self, buffer: Union[str, bytes]) -> PQGen[None]:
- # if write() was called, assume the header was sent together with the
- # first block of data.
+ def _format_write(self, buffer: Union[str, bytes]) -> bytes:
+ data = self._ensure_bytes(buffer)
self._signature_sent = True
- yield from copy_to(self._pgconn, self._ensure_bytes(buffer))
+ return data
- def _write_row_gen(self, row: Sequence[Any]) -> PQGen[None]:
+ def _format_write_row(self, row: Sequence[Any]) -> bytes:
# Note down that we are writing in row mode: it means we will have
# to take care of the end-of-copy marker too
self._row_mode = True
self._signature_sent = True
self._format_row(row, self.transformer, self._write_buffer)
- if len(self._write_buffer) > self._write_buffer_size:
- yield from copy_to(self._pgconn, self._write_buffer)
- self._write_buffer.clear()
-
- def _finish_gen(self, error: str = "") -> PQGen[None]:
- if error:
- berr = error.encode(self.connection.client_encoding, "replace")
- res = yield from copy_end(self._pgconn, berr)
+ if len(self._write_buffer) > self.BUFFER_SIZE:
+ buffer, self._write_buffer = self._write_buffer, bytearray()
+ return buffer
+ else:
+ return b""
+
+ def _format_end(self) -> bytes:
+ """Return the data left to send in order to complete the copy.
+
+ In binary format, add to the pending buffer the signature and the
+ trailer if nothing was sent yet, or only the trailer if rows were
+ written in row mode. The pending buffer is then returned and replaced
+ by a new empty one.
+ """
+ # NOTE(review): the object returned is the `bytearray` write buffer,
+ # although the annotation says `bytes`.
+ if self.format == BINARY:
+ # If we have sent no data we need to send the signature
+ # and the trailer
+ if not self._signature_sent:
+ self._write_buffer += _binary_signature
+ self._write_buffer += _binary_trailer
+ elif self._row_mode:
+
+ # if we have sent data already, we have sent the signature
+ # too (either with the first row, or we assume that in
+ # block mode the signature is included).
+ # Write the trailer only if we are sending rows (with the
+ # assumption that who is copying binary data is sending the
+ # whole format).
+ self._write_buffer += _binary_trailer
+
+ # Hand over the accumulated buffer and start again from an empty one
+ buffer, self._write_buffer = self._write_buffer, bytearray()
+ return buffer
+
+ def _end_copy_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
+ bmsg: Optional[bytes]
+ if exc:
+ msg = f"error from Python: {type(exc).__qualname__} - {exc}"
+ bmsg = msg.encode(self.connection.client_encoding, "replace")
else:
- if self.format == BINARY:
- # If we have sent no data we need to send the signature
- # and the trailer
- if not self._signature_sent:
- self._write_buffer += _binary_signature
- self._write_buffer += _binary_trailer
- elif self._row_mode:
-
- # if we have sent data already, we have sent the signature
- # too (either with the first row, or we assume that in
- # block mode the signature is included).
- # Write the trailer only if we are sending rows (with the
- # assumption that who is copying binary data is sending the
- # whole format).
- self._write_buffer += _binary_trailer
-
- if self._write_buffer:
- yield from copy_to(self._pgconn, self._write_buffer)
- self._write_buffer.clear()
- res = yield from copy_end(self._pgconn, None)
+ bmsg = None
+
+ res = yield from copy_end(self._pgconn, bmsg)
nrows = res.command_tuples
self.cursor._rowcount = nrows if nrows is not None else -1
self._finished = True
- def _exit_gen(
- self,
- exc_type: Optional[Type[BaseException]],
- exc_val: Optional[BaseException],
- ) -> PQGen[None]:
- # no-op in COPY TO
- if self._pgresult.status == ExecStatus.COPY_OUT:
- return
-
- # In case of error in Python let's quit it here
- if exc_type:
- yield from self._finish_gen(
- f"error from Python: {exc_type.__qualname__} - {exc_val}"
- )
- else:
- yield from self._finish_gen()
-
# Support methods
def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
__module__ = "psycopg3"
+ def __init__(self, cursor: "Cursor"):
+ super().__init__(cursor)
+ # Buffers waiting to be sent by the worker; None signals the end of the
+ # data. The queue holds at most QUEUE_SIZE items.
+ self._queue: queue.Queue[Optional[bytes]] = queue.Queue(
+ maxsize=self.QUEUE_SIZE
+ )
+ # Thread running worker(): created by _write() on the first non-empty
+ # buffer and reset to None by _write_end().
+ self._worker: Optional[threading.Thread] = None
+
+ def __enter__(self) -> "Copy":
+ """Enter the copy block, returning the `!Copy` object itself."""
+ # _check_reuse() is defined elsewhere: presumably it refuses to use the
+ # same object for more than one copy operation -- TODO confirm.
+ self._check_reuse()
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ """Terminate the copy on block exit by calling `finish()`.
+
+ The exception raised in the block, if any, is passed to `!finish()`.
+ Nothing is returned, so the exception is not suppressed here.
+ """
+ self.finish(exc_val)
+
+ # End user sync interface
+
+ def __iter__(self) -> Iterator[memoryview]:
+ """Implement block-by-block iteration on :sql:`COPY TO`."""
+ while True:
+ data = self.read()
+ # An empty result from read() ends the iteration
+ if not data:
+ break
+ yield data
+
def read(self) -> memoryview:
"""
Read an unparsed row after a :sql:`COPY TO` operation.
If the :sql:`COPY` is in binary format *buffer* must be `!bytes`. In
text mode it can be either `!bytes` or `!str`.
"""
- self.connection.wait(self._write_gen(buffer))
+ data = self._format_write(buffer)
+ self._write(data)
def write_row(self, row: Sequence[Any]) -> None:
"""Write a record to a table after a :sql:`COPY FROM` operation."""
- self.connection.wait(self._write_row_gen(row))
+ data = self._format_write_row(row)
+ self._write(data)
- def _finish(self, error: str = "") -> None:
- """Terminate a :sql:`COPY FROM` operation."""
- self.connection.wait(self._finish_gen(error))
+ def finish(self, exc: Optional[BaseException]) -> None:
+ """Terminate the copy operation and free the resources allocated.
- def __enter__(self) -> "Copy":
+ You shouldn't need to call this function yourself: it is usually called
+ by `!__exit__()` on leaving the `!with` block, passing the exception
+ raised in the block, if any. It is available if, despite what is
+ documented, you end up using the `Copy` object outside a block.
+ """
+ # no-op in COPY TO
+ if self._pgresult.status == ExecStatus.COPY_OUT:
+ return
+
+ self._write_end()
+ self.connection.wait(self._end_copy_gen(exc))
+
+ # Concurrent copy support
+
+ def worker(self) -> None:
+ """Push data to the server when available from the copy queue.
+
+ Terminate reading when the queue receives a None.
+
+ The function is designed to be run in a separate thread.
+ """
+ while 1:
+ # Block until a buffer is available: `queue.Empty` is raised if
+ # nothing arrives within 24 hours.
+ data = self._queue.get(block=True, timeout=24 * 60 * 60)
+ # None is the end-of-data marker. An empty buffer would stop the loop
+ # too, but _write() never queues one.
+ if not data:
+ break
+ self.connection.wait(copy_to(self._pgconn, data))
+
+ def _write(self, data: bytes) -> None:
+ """Queue *data* to be sent to the server by the worker thread.
+
+ Do nothing if *data* is empty. Start the worker thread on first use.
+ The call blocks while the queue is full.
+ """
+ if not data:
+ return
+
+ if not self._worker:
+ # warning: reference loop, broken by _write_end
+ self._worker = threading.Thread(target=self.worker)
+ # Daemon thread: it doesn't keep the interpreter from exiting
+ self._worker.daemon = True
+ self._worker.start()
+
+ self._queue.put(data)
+
+ def _write_end(self) -> None:
+ """Queue the final data and the end marker, then wait for the worker."""
+ data = self._format_end()
+ self._write(data)
+ # None tells the worker to stop. It is queued even when no worker was
+ # started because nothing was written.
+ self._queue.put(None)
+
+ if self._worker:
+ # Wait for the worker to send everything that was queued
+ self._worker.join()
+ self._worker = None # break the loop
+
+
+class AsyncCopy(BaseCopy["AsyncConnection"]):
+ """Manage an asynchronous :sql:`COPY` operation."""
+
+ __module__ = "psycopg3"
+
+ def __init__(self, cursor: "AsyncCursor"):
+ super().__init__(cursor)
+ # Buffers waiting to be sent by the worker; None signals the end of the
+ # data. The queue holds at most QUEUE_SIZE items.
+ self._queue: asyncio.Queue[Optional[bytes]] = asyncio.Queue(
+ maxsize=self.QUEUE_SIZE
+ )
+ # Future wrapping the worker() coroutine: created by _write() on the
+ # first non-empty buffer and reset to None by _write_end().
+ self._worker: Optional[asyncio.Future[None]] = None
+
+ async def __aenter__(self) -> "AsyncCopy":
self._check_reuse()
return self
- def __exit__(
+ async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
- self.connection.wait(self._exit_gen(exc_type, exc_val))
+ await self.finish(exc_val)
- def __iter__(self) -> Iterator[memoryview]:
+ async def __aiter__(self) -> AsyncIterator[memoryview]:
while True:
- data = self.read()
+ data = await self.read()
if not data:
break
yield data
-
-class AsyncCopy(BaseCopy["AsyncConnection"]):
- """Manage an asynchronous :sql:`COPY` operation."""
-
- __module__ = "psycopg3"
-
async def read(self) -> memoryview:
return await self.connection.wait(self._read_gen())
return await self.connection.wait(self._read_row_gen())
async def write(self, buffer: Union[str, bytes]) -> None:
- await self.connection.wait(self._write_gen(buffer))
+ data = self._format_write(buffer)
+ await self._write(data)
async def write_row(self, row: Sequence[Any]) -> None:
- await self.connection.wait(self._write_row_gen(row))
+ data = self._format_write_row(row)
+ await self._write(data)
- async def _finish(self, error: str = "") -> None:
- await self.connection.wait(self._finish_gen(error))
+ async def finish(self, exc: Optional[BaseException]) -> None:
+ # no-op in COPY TO
+ if self._pgresult.status == ExecStatus.COPY_OUT:
+ return
- async def __aenter__(self) -> "AsyncCopy":
- self._check_reuse()
- return self
+ await self._write_end()
+ await self.connection.wait(self._end_copy_gen(exc))
- async def __aexit__(
- self,
- exc_type: Optional[Type[BaseException]],
- exc_val: Optional[BaseException],
- exc_tb: Optional[TracebackType],
- ) -> None:
- await self.connection.wait(self._exit_gen(exc_type, exc_val))
+ # Concurrent copy support
- async def __aiter__(self) -> AsyncIterator[memoryview]:
- while True:
- data = await self.read()
+ async def worker(self) -> None:
+ """Push data to the server when available from the copy queue.
+
+ Terminate reading when the queue receives a None.
+
+ The function is designed to be run as an asyncio task, concurrently with
+ the coroutine writing the data.
+ """
+ while 1:
+ data = await self._queue.get()
if not data:
break
- yield data
+ await self.connection.wait(copy_to(self._pgconn, data))
+
+ async def _write(self, data: bytes) -> None:
+ """Queue *data* to be sent to the server by the worker task.
+
+ Do nothing if *data* is empty. Schedule the worker on first use.
+ Wait for a free slot while the queue is full.
+ """
+ if not data:
+ return
+
+ if not self._worker:
+ # TODO: can be asyncio.create_task once Python 3.6 is dropped
+ self._worker = asyncio.ensure_future(self.worker())
+
+ await self._queue.put(data)
+
+ async def _write_end(self) -> None:
+ """Queue the final data and the end marker, then wait for the worker."""
+ data = self._format_end()
+ await self._write(data)
+ # None tells the worker to stop. It is queued even when no worker was
+ # scheduled because nothing was written.
+ await self._queue.put(None)
+
+ if self._worker:
+ # Wait for the worker to complete: an exception raised in it is
+ # propagated to the caller here.
+ await asyncio.gather(self._worker)
+ self._worker = None # break reference loops if any
def _format_row_text(