ACTIVE = pq.TransactionStatus.ACTIVE
# Size of data to accumulate before sending it down the network. We fill a
# buffer this size field by field, and when it passes the threshold size
# we ship it, so it may end up being bigger than this.
BUFFER_SIZE = 32 * 1024

# Maximum data size we want to queue to send to the libpq copy. Sending a
# buffer too big to be handled can cause an infinite loop in the libpq
# (#255) so we want to split it in more digestible chunks.
MAX_BUFFER_SIZE = 4 * BUFFER_SIZE
# Note: making this buffer too large, e.g.
# MAX_BUFFER_SIZE = 1024 * 1024
# makes operations *way* slower! Probably triggering some quadraticity
# in the libpq memory management and data sending.

# Max size of the write queue of buffers: past that, `write()` will block.
# Each buffer should be around BUFFER_SIZE size.
QUEUE_SIZE = 1024
+
class BaseCopy(Generic[ConnectionType]):
"""
_Self = TypeVar("_Self", bound="BaseCopy[ConnectionType]")
- # Max size of the write queue of buffers. More than that copy will block
- # Each buffer should be around BUFFER_SIZE size.
- QUEUE_SIZE = 1024
-
- # Size of data to accumulate before sending it down the network. We fill a
- # buffer this size field by field, and when it passes the threshold size
- # we ship it, so it may end up being bigger than this.
- BUFFER_SIZE = 32 * 1024
-
- # Maximum data size we want to queue to send to the libpq copy. Sending a
- # buffer too big to be handled can cause an infinite loop in the libpq
- # (#255) so we want to split it in more digestable chunks.
- MAX_BUFFER_SIZE = 4 * BUFFER_SIZE
- # Note: making this buffer too large, e.g.
- # MAX_BUFFER_SIZE = 1024 * 1024
- # makes operations *way* slower! Probably triggering some quadraticity
- # in the libpq memory management and data sending.
-
formatter: "Formatter"
+ writer: "Writer[ConnectionType]"
def __init__(self, cursor: "BaseCursor[ConnectionType, Any]"):
self.cursor = cursor
__module__ = "psycopg"
- def __init__(self, cursor: "Cursor[Any]"):
+ def __init__(
+ self, cursor: "Cursor[Any]", writer: Optional["Writer[Connection[Any]]"] = None
+ ):
super().__init__(cursor)
- self._queue: queue.Queue[bytes] = queue.Queue(maxsize=self.QUEUE_SIZE)
- self._worker: Optional[threading.Thread] = None
- self._worker_error: Optional[BaseException] = None
+ if not writer:
+ writer = QueueWriter(cursor)
+
+ self.writer = writer
+ self._write = writer.write
+ self._write_end = writer.write_end
def __enter__(self: BaseCopy._Self) -> BaseCopy._Self:
self._enter()
using the `Copy` object outside a block.
"""
if self._pgresult.status == COPY_IN:
- self._write_end()
+ data = self.formatter.end()
+ self._write_end(data)
self.connection.wait(self._end_copy_in_gen(exc))
else:
self.connection.wait(self._end_copy_out_gen(exc))
- # Concurrent copy support
-
- def worker(self) -> None:
- """Push data to the server when available from the copy queue.
-
- Terminate reading when the queue receives a false-y value, or in case
- of error.
-
- The function is designed to be run in a separate thread.
- """
- try:
- while True:
- data = self._queue.get(block=True, timeout=24 * 60 * 60)
- if not data:
- break
- self.connection.wait(copy_to(self._pgconn, data))
- except BaseException as ex:
- # Propagate the error to the main thread.
- self._worker_error = ex
-
- def _write(self, data: Buffer) -> None:
- if not self._worker:
- # warning: reference loop, broken by _write_end
- self._worker = threading.Thread(target=self.worker)
- self._worker.daemon = True
- self._worker.start()
-
- # If the worker thread raies an exception, re-raise it to the caller.
- if self._worker_error:
- raise self._worker_error
-
- if len(data) <= self.MAX_BUFFER_SIZE:
- # Most used path: we don't need to split the buffer in smaller
- # bits, so don't make a copy.
- self._queue.put(data)
- else:
- # Copy a buffer too large in chunks to avoid causing a memory
- # error in the libpq, which may cause an infinite loop (#255).
- for i in range(0, len(data), self.MAX_BUFFER_SIZE):
- self._queue.put(data[i : i + self.MAX_BUFFER_SIZE])
-
- def _write_end(self) -> None:
- data = self.formatter.end()
- if data:
- self._write(data)
- self._queue.put(b"")
-
- if self._worker:
- self._worker.join()
- self._worker = None # break the loop
-
- # Check if the worker thread raised any exception before terminating.
- if self._worker_error:
- raise self._worker_error
-
class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
"""Manage an asynchronous :sql:`COPY` operation."""
def __init__(self, cursor: "AsyncCursor[Any]"):
super().__init__(cursor)
- self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=self.QUEUE_SIZE)
+ self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=QUEUE_SIZE)
self._worker: Optional[asyncio.Future[None]] = None
async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self:
if not self._worker:
self._worker = create_task(self.worker())
- if len(data) <= self.MAX_BUFFER_SIZE:
+ if len(data) <= MAX_BUFFER_SIZE:
# Most used path: we don't need to split the buffer in smaller
# bits, so don't make a copy.
await self._queue.put(data)
else:
# Copy a buffer too large in chunks to avoid causing a memory
# error in the libpq, which may cause an infinite loop (#255).
- for i in range(0, len(data), self.MAX_BUFFER_SIZE):
- await self._queue.put(data[i : i + self.MAX_BUFFER_SIZE])
+ for i in range(0, len(data), MAX_BUFFER_SIZE):
+ await self._queue.put(data[i : i + MAX_BUFFER_SIZE])
async def _write_end(self) -> None:
data = self.formatter.end()
self._worker = None # break reference loops if any
class Writer(Generic[ConnectionType], ABC):
    """
    A class to write copy data somewhere.
    """

    @abstractmethod
    def write(self, data: Buffer) -> None:
        """Write a chunk of copy data to the destination."""
        ...

    @abstractmethod
    def write_end(self, data: Buffer) -> None:
        """Write the final chunk of copy data and terminate the stream."""
        ...


class ConnectionWriter(Writer[ConnectionType]):
    """
    Base class for writers operating on a database connection.

    Keeps a reference to the cursor's connection and to its underlying
    libpq connection.
    """

    def __init__(self, cursor: "BaseCursor[ConnectionType, Any]"):
        conn = cursor.connection
        self.connection = conn
        self._pgconn = conn.pgconn


class QueueWriter(ConnectionWriter["Connection[Any]"]):
    """
    A writer using a buffer to queue data to write.

    `write()` returns immediately, so that the main thread can be CPU-bound
    formatting messages, while a worker thread can be IO-bound waiting to write
    on the connection.
    """

    def __init__(self, cursor: "Cursor[Any]"):
        super().__init__(cursor)

        # Queue of buffers to ship to the connection; a false-y item is the
        # termination sentinel. Bounded, so write() blocks if the worker
        # falls too far behind. Items may be any Buffer, not only bytes.
        self._queue: queue.Queue[Buffer] = queue.Queue(maxsize=QUEUE_SIZE)
        self._worker: Optional[threading.Thread] = None
        self._worker_error: Optional[BaseException] = None

    def worker(self) -> None:
        """Push data to the server when available from the copy queue.

        Terminate reading when the queue receives a false-y value, or in case
        of error.

        The function is designed to be run in a separate thread.
        """
        try:
            while True:
                # NOTE(review): a finite (if huge) timeout instead of a plain
                # blocking get() — presumably to keep the wait interruptible;
                # confirm before changing.
                data = self._queue.get(block=True, timeout=24 * 60 * 60)
                if not data:
                    break
                self.connection.wait(copy_to(self._pgconn, data))
        except BaseException as ex:
            # Propagate the error to the main thread.
            self._worker_error = ex

    def write(self, data: Buffer) -> None:
        """Queue *data* to be sent, starting the worker thread if needed.

        Re-raises in the caller's thread any error previously raised by the
        worker.
        """
        if not self._worker:
            # warning: reference loop, broken by write_end()
            self._worker = threading.Thread(target=self.worker)
            self._worker.daemon = True
            self._worker.start()

        # If the worker thread raises an exception, re-raise it to the caller.
        if self._worker_error:
            raise self._worker_error

        if len(data) <= MAX_BUFFER_SIZE:
            # Most used path: we don't need to split the buffer in smaller
            # bits, so don't make a copy.
            self._queue.put(data)
        else:
            # Copy a buffer too large in chunks to avoid causing a memory
            # error in the libpq, which may cause an infinite loop (#255).
            for i in range(0, len(data), MAX_BUFFER_SIZE):
                self._queue.put(data[i : i + MAX_BUFFER_SIZE])

    def write_end(self, data: Buffer) -> None:
        """Queue the final *data*, signal termination, and join the worker."""
        if data:
            self.write(data)
        self._queue.put(b"")  # sentinel: tell the worker to stop

        if self._worker:
            self._worker.join()
            self._worker = None  # break the reference loop

        # Check if the worker thread raised any exception before terminating.
        if self._worker_error:
            raise self._worker_error


class Formatter(ABC):
"""
A class which understand a copy format (text, binary).
"""
format: pq.Format
- BUFFER_SIZE = BaseCopy.BUFFER_SIZE
def __init__(self, transformer: Transformer):
self.transformer = transformer
self._row_mode = True
format_row_text(row, self.transformer, self._write_buffer)
- if len(self._write_buffer) > self.BUFFER_SIZE:
+ if len(self._write_buffer) > BUFFER_SIZE:
buffer, self._write_buffer = self._write_buffer, bytearray()
return buffer
else:
self._signature_sent = True
format_row_binary(row, self.transformer, self._write_buffer)
- if len(self._write_buffer) > self.BUFFER_SIZE:
+ if len(self._write_buffer) > BUFFER_SIZE:
buffer, self._write_buffer = self._write_buffer, bytearray()
return buffer
else: