_Self = TypeVar("_Self", bound="BaseCopy[ConnectionType]")
formatter: "Formatter"
- writer: "Writer[ConnectionType]"
def __init__(self, cursor: "BaseCursor[ConnectionType, Any]"):
self.cursor = cursor
__module__ = "psycopg"
- def __init__(
- self,
- cursor: "Cursor[Any]",
- *,
- writer: Optional["Writer[Connection[Any]]"] = None,
- ):
+ writer: "Writer"
+
+ def __init__(self, cursor: "Cursor[Any]", *, writer: Optional["Writer"] = None):
super().__init__(cursor)
if not writer:
writer = QueueWriter(cursor.connection)
self.connection.wait(self._end_copy_out_gen(exc))
class Writer(ABC):
    """
    A class to write copy data somewhere.

    Abstract base for the objects a `Copy` uses to ship formatted COPY
    data to its destination.
    """

    @abstractmethod
    def write(self, data: "Buffer") -> None:
        """
        Write some data to destination.

        :param data: the chunk of COPY data to write.
        """
        ...

    def finish(self) -> None:
        """
        Called when write operations are finished.

        No-op by default: subclasses may override it to flush buffers,
        join worker threads, and report deferred errors.
        """
        pass
+
+
class ConnectionWriter(Writer):
    """
    A `Writer` sending data straight to the connection, blocking the caller.
    """

    def __init__(self, connection: "Connection[Any]"):
        self.connection = connection
        # Cache the low-level libpq connection used by copy_to().
        self._pgconn = self.connection.pgconn

    def write(self, data: "Buffer") -> None:
        """
        Send *data* on the connection, splitting oversized buffers.
        """
        if len(data) <= MAX_BUFFER_SIZE:
            # Most used path: we don't need to split the buffer in smaller
            # bits, so don't make a copy.
            self.connection.wait(copy_to(self._pgconn, data))
        else:
            # Copy a buffer too large in chunks to avoid causing a memory
            # error in the libpq, which may cause an infinite loop (#255).
            for i in range(0, len(data), MAX_BUFFER_SIZE):
                self.connection.wait(
                    copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
                )
+
+
class QueueWriter(ConnectionWriter):
    """
    A writer using a buffer to queue data to write.

    `write()` returns immediately, so that the main thread can be CPU-bound
    formatting messages, while a worker thread can be IO-bound waiting to write
    on the connection.
    """

    def __init__(self, connection: "Connection[Any]"):
        super().__init__(connection)

        # Bounded queue: write() blocks once QUEUE_SIZE chunks are pending.
        self._queue: queue.Queue[bytes] = queue.Queue(maxsize=QUEUE_SIZE)
        self._worker: Optional[threading.Thread] = None
        # Exception raised in the worker, re-raised to the caller by
        # write()/finish().
        self._worker_error: Optional[BaseException] = None

    def worker(self) -> None:
        """Push data to the server when available from the copy queue.

        Terminate reading when the queue receives a false-y value, or in case
        of error.

        The function is designed to be run in a separate thread.
        """
        try:
            while True:
                # NOTE(review): the huge timeout presumably keeps the blocking
                # get() interruptible; after 24h a queue.Empty would escape the
                # loop via the except below — confirm this is intended.
                data = self._queue.get(block=True, timeout=24 * 60 * 60)
                if not data:
                    break
                self.connection.wait(copy_to(self._pgconn, data))
        except BaseException as ex:
            # Propagate the error to the main thread.
            self._worker_error = ex

    def write(self, data: "Buffer") -> None:
        """
        Queue *data* for the worker thread, starting it on first use.
        """
        if not self._worker:
            # warning: reference loop, broken by finish()
            self._worker = threading.Thread(target=self.worker)
            self._worker.daemon = True
            self._worker.start()

        # If the worker thread raised an exception, re-raise it to the caller.
        if self._worker_error:
            raise self._worker_error

        if len(data) <= MAX_BUFFER_SIZE:
            # Most used path: we don't need to split the buffer in smaller
            # bits, so don't make a copy.
            self._queue.put(data)
        else:
            # Copy a buffer too large in chunks to avoid causing a memory
            # error in the libpq, which may cause an infinite loop (#255).
            for i in range(0, len(data), MAX_BUFFER_SIZE):
                self._queue.put(data[i : i + MAX_BUFFER_SIZE])

    def finish(self) -> None:
        """
        Stop the worker thread and re-raise any error it hit.
        """
        # A false-y sentinel terminates the worker loop.
        self._queue.put(b"")

        if self._worker:
            self._worker.join()
            self._worker = None  # break the loop

        # Check if the worker thread raised any exception before terminating.
        if self._worker_error:
            raise self._worker_error
+
+
class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
    """Manage an asynchronous :sql:`COPY` operation."""

    __module__ = "psycopg"

    # The object performing the actual data transfer for this COPY.
    writer: "AsyncWriter"

    def __init__(
        self, cursor: "AsyncCursor[Any]", *, writer: Optional["AsyncWriter"] = None
    ):
        """
        :param cursor: the cursor the COPY operation runs on.
        :param writer: where to send the data; by default an
            `AsyncQueueWriter` on the cursor's connection.
        """
        super().__init__(cursor)

        if not writer:
            writer = AsyncQueueWriter(cursor.connection)

        self.writer = writer
        # Shortcut to the writer's write method.
        self._write = writer.write
async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self:
self._enter()
async def finish(self, exc: Optional[BaseException]) -> None:
if self._pgresult.status == COPY_IN:
- await self._write_end()
+ data = self.formatter.end()
+ if data:
+ await self._write(data)
+ await self.writer.finish()
await self.connection.wait(self._end_copy_in_gen(exc))
else:
await self.connection.wait(self._end_copy_out_gen(exc))
- # Concurrent copy support
- async def worker(self) -> None:
- """Push data to the server when available from the copy queue.
-
- Terminate reading when the queue receives a false-y value.
-
- The function is designed to be run in a separate thread.
- """
- while True:
- data = await self._queue.get()
- if not data:
- break
- await self.connection.wait(copy_to(self._pgconn, data))
-
- async def _write(self, data: Buffer) -> None:
- if not self._worker:
- self._worker = create_task(self.worker())
-
- if len(data) <= MAX_BUFFER_SIZE:
- # Most used path: we don't need to split the buffer in smaller
- # bits, so don't make a copy.
- await self._queue.put(data)
- else:
- # Copy a buffer too large in chunks to avoid causing a memory
- # error in the libpq, which may cause an infinite loop (#255).
- for i in range(0, len(data), MAX_BUFFER_SIZE):
- await self._queue.put(data[i : i + MAX_BUFFER_SIZE])
-
- async def _write_end(self) -> None:
- data = self.formatter.end()
- if data:
- await self._write(data)
- await self._queue.put(b"")
-
- if self._worker:
- await asyncio.gather(self._worker)
- self._worker = None # break reference loops if any
-
-
class AsyncWriter(ABC):
    """
    A class to write copy data somewhere (asynchronous interface).
    """

    @abstractmethod
    async def write(self, data: "Buffer") -> None:
        """
        Write some data to destination.

        :param data: the chunk of COPY data to write.
        """
        ...

    async def finish(self) -> None:
        """
        Called when write operations are finished.

        No-op by default: subclasses may override it to flush queues and
        await worker tasks.
        """
        pass
class AsyncConnectionWriter(AsyncWriter):
    """
    An `AsyncWriter` sending data straight to the connection.
    """

    def __init__(self, connection: "AsyncConnection[Any]"):
        self.connection = connection
        # Cache the low-level libpq connection used by copy_to().
        self._pgconn = self.connection.pgconn

    async def write(self, data: "Buffer") -> None:
        """
        Send *data* on the connection, splitting oversized buffers.
        """
        if len(data) <= MAX_BUFFER_SIZE:
            # Most used path: we don't need to split the buffer in smaller
            # bits, so don't make a copy.
            await self.connection.wait(copy_to(self._pgconn, data))
        else:
            # Copy a buffer too large in chunks to avoid causing a memory
            # error in the libpq, which may cause an infinite loop (#255).
            for i in range(0, len(data), MAX_BUFFER_SIZE):
                await self.connection.wait(
                    copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
                )
class AsyncQueueWriter(AsyncConnectionWriter):
    """
    A writer using a buffer to queue data to write.

    `write()` returns immediately, so that the main task can be CPU-bound
    formatting messages, while a worker task can be IO-bound waiting to write
    on the connection.
    """

    def __init__(self, connection: "AsyncConnection[Any]"):
        super().__init__(connection)

        # Bounded queue: write() blocks once QUEUE_SIZE chunks are pending.
        self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=QUEUE_SIZE)
        self._worker: Optional[asyncio.Future[None]] = None

    async def worker(self) -> None:
        """Push data to the server when available from the copy queue.

        Terminate reading when the queue receives a false-y value.

        The function is designed to be run in a separate task.
        """
        while True:
            data = await self._queue.get()
            if not data:
                break
            await self.connection.wait(copy_to(self._pgconn, data))

    async def write(self, data: "Buffer") -> None:
        """
        Queue *data* for the worker task, starting it on first use.
        """
        if not self._worker:
            self._worker = create_task(self.worker())

        if len(data) <= MAX_BUFFER_SIZE:
            # Most used path: we don't need to split the buffer in smaller
            # bits, so don't make a copy.
            await self._queue.put(data)
        else:
            # Copy a buffer too large in chunks to avoid causing a memory
            # error in the libpq, which may cause an infinite loop (#255).
            for i in range(0, len(data), MAX_BUFFER_SIZE):
                await self._queue.put(data[i : i + MAX_BUFFER_SIZE])

    async def finish(self) -> None:
        """
        Stop the worker task and wait for it to complete.
        """
        # A false-y sentinel terminates the worker loop.
        await self._queue.put(b"")

        if self._worker:
            # gather() also re-raises any exception raised in the worker.
            await asyncio.gather(self._worker)
            self._worker = None  # break reference loops if any
class Formatter(ABC):