import struct
import asyncio
import threading
+from abc import ABC, abstractmethod
from types import TracebackType
from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union
from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple
formatting the data in copy format and adding it to the queue.
"""
- # Size of data to accumulate before sending it down the network
- BUFFER_SIZE = 32 * 1024
-
# Max size of the write queue of buffers. More than that copy will block
+ # Each buffer around Formatter.BUFFER_SIZE size
QUEUE_SIZE = 1024
+ formatter: "Formatter"
+
def __init__(self, cursor: "BaseCursor[ConnectionType]"):
self.cursor = cursor
self.connection = cursor.connection
- self.transformer = cursor._transformer
self._pgconn = self.connection.pgconn
- assert (
- self.transformer.pgresult
- ), "The Transformer doesn't have a PGresult set"
- self._pgresult: "PGresult" = self.transformer.pgresult
-
- self.format = pq.Format(self._pgresult.binary_tuples)
- self._encoding = self.connection.client_encoding
- self._signature_sent = False
- self._row_mode = False # true if the user is using send_row()
- self._write_buffer = bytearray()
- self._finished = False
+ tx = cursor._transformer
+ assert tx.pgresult, "The Transformer doesn't have a PGresult set"
+ self._pgresult: "PGresult" = tx.pgresult
- if self.format == TEXT:
- self._format_row = format_row_text
- self._parse_row = parse_row_text
+ if self._pgresult.binary_tuples == pq.Format.TEXT:
+ self.formatter = TextFormatter(
+ tx, encoding=self.connection.client_encoding
+ )
else:
- self._format_row = format_row_binary
- self._parse_row = parse_row_binary
+ self.formatter = BinaryFormatter(tx)
+
+ self._finished = False
def __repr__(self) -> str:
cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
info = pq.misc.connection_summary(self._pgconn)
return f"<{cls} {info} at 0x{id(self):x}>"
+ def _enter(self) -> None:
+ if self._finished:
+ raise TypeError("copy blocks can be used only once")
+
def set_types(self, types: Sequence[int]) -> None:
"""
Set the types expected out of a :sql:`COPY TO` operation.
Without setting the types, the data from :sql:`COPY TO` will be
returned as unparsed strings or bytes.
"""
- self.transformer.set_row_types(types, [self.format] * len(types))
+ self.formatter.transformer.set_row_types(
+ types, [self.formatter.format] * len(types)
+ )
# High level copy protocol generators (state change of the Copy object)
if not data:
return None
- if self.format == BINARY:
- if not self._signature_sent:
- if data[: len(_binary_signature)] != _binary_signature:
- raise e.DataError(
- "binary copy doesn't start with the expected signature"
- )
- self._signature_sent = True
- data = data[len(_binary_signature) :]
-
- elif data == _binary_trailer:
- yield from self._read_gen()
- self._finished = True
- return None
-
- return self._parse_row(data, self.transformer)
-
- def _format_write(self, buffer: Union[str, bytes]) -> bytes:
- data = self._ensure_bytes(buffer)
- self._signature_sent = True
- return data
-
- def _format_write_row(self, row: Sequence[Any]) -> bytes:
- # Note down that we are writing in row mode: it means we will have
- # to take care of the end-of-copy marker too
- self._row_mode = True
-
- if self.format == BINARY and not self._signature_sent:
- self._write_buffer += _binary_signature
- self._signature_sent = True
-
- self._format_row(row, self.transformer, self._write_buffer)
- if len(self._write_buffer) > self.BUFFER_SIZE:
- buffer, self._write_buffer = self._write_buffer, bytearray()
- return buffer
- else:
- return b""
+ row = self.formatter.parse_row(data)
+ if row is None:
+ # Get the final result to finish the copy operation
+ yield from self._read_gen()
+ self._finished = True
+ return None
- def _format_end(self) -> bytes:
- if self.format == BINARY:
- # If we have sent no data we need to send the signature
- # and the trailer
- if not self._signature_sent:
- self._write_buffer += _binary_signature
- self._write_buffer += _binary_trailer
- elif self._row_mode:
-
- # if we have sent data already, we have sent the signature
- # too (either with the first row, or we assume that in
- # block mode the signature is included).
- # Write the trailer only if we are sending rows (with the
- # assumption that who is copying binary data is sending the
- # whole format).
- self._write_buffer += _binary_trailer
-
- buffer, self._write_buffer = self._write_buffer, bytearray()
- return buffer
+ return row
def _end_copy_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
bmsg: Optional[bytes]
self.cursor._rowcount = nrows if nrows is not None else -1
self._finished = True
- # Support methods
-
- def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
- if isinstance(data, bytes):
- return data
-
- elif isinstance(data, str):
- if self._pgresult.binary_tuples == BINARY:
- raise TypeError(
- "cannot copy str data in binary mode: use bytes instead"
- )
- return data.encode(self._encoding)
-
- else:
- raise TypeError(f"can't write {type(data).__name__}")
-
- def _check_reuse(self) -> None:
- if self._finished:
- raise TypeError("copy blocks can be used only once")
-
class Copy(BaseCopy["Connection"]):
"""Manage a :sql:`COPY` operation."""
self._worker: Optional[threading.Thread] = None
def __enter__(self) -> "Copy":
- self._check_reuse()
+ self._enter()
return self
def __exit__(
If the :sql:`COPY` is in binary format *buffer* must be `!bytes`. In
text mode it can be either `!bytes` or `!str`.
"""
- data = self._format_write(buffer)
+ data = self.formatter.write(buffer)
self._write(data)
def write_row(self, row: Sequence[Any]) -> None:
"""Write a record to a table after a :sql:`COPY FROM` operation."""
- data = self._format_write_row(row)
+ data = self.formatter.write_row(row)
self._write(data)
def finish(self, exc: Optional[BaseException]) -> None:
self._queue.put(data)
def _write_end(self) -> None:
- data = self._format_end()
+ data = self.formatter.end()
self._write(data)
self._queue.put(None)
self._worker: Optional[asyncio.Future[None]] = None
async def __aenter__(self) -> "AsyncCopy":
- self._check_reuse()
+ self._enter()
return self
async def __aexit__(
return await self.connection.wait(self._read_row_gen())
async def write(self, buffer: Union[str, bytes]) -> None:
- data = self._format_write(buffer)
+ data = self.formatter.write(buffer)
await self._write(data)
async def write_row(self, row: Sequence[Any]) -> None:
- data = self._format_write_row(row)
+ data = self.formatter.write_row(row)
await self._write(data)
async def finish(self, exc: Optional[BaseException]) -> None:
await self._queue.put(data)
async def _write_end(self) -> None:
- data = self._format_end()
+ data = self.formatter.end()
await self._write(data)
await self._queue.put(None)
self._worker = None # break reference loops if any
+class Formatter(ABC):
+ """
+ A class which understand a copy format (text, binary).
+ """
+
+ format: pq.Format
+
+ # Size of data to accumulate before sending it down the network
+ BUFFER_SIZE = 32 * 1024
+
+ def __init__(self, transformer: Transformer):
+ self.transformer = transformer
+ self._write_buffer = bytearray()
+ self._row_mode = False # true if the user is using send_row()
+
+ @abstractmethod
+ def parse_row(self, data: bytes) -> Optional[Tuple[Any, ...]]:
+ ...
+
+ @abstractmethod
+ def write(self, buffer: Union[str, bytes]) -> bytes:
+ ...
+
+ @abstractmethod
+ def write_row(self, row: Sequence[Any]) -> bytes:
+ ...
+
+ @abstractmethod
+ def end(self) -> bytes:
+ ...
+
+
+class TextFormatter(Formatter):
+
+ format = pq.Format.TEXT
+
+ def __init__(self, transformer: Transformer, encoding: str = "utf-8"):
+ super().__init__(transformer)
+ self._encoding = encoding
+
+ def parse_row(self, data: bytes) -> Optional[Tuple[Any, ...]]:
+ if data:
+ return parse_row_text(data, self.transformer)
+ else:
+ return None
+
+ def write(self, buffer: Union[str, bytes]) -> bytes:
+ data = self._ensure_bytes(buffer)
+ self._signature_sent = True
+ return data
+
+ def write_row(self, row: Sequence[Any]) -> bytes:
+ # Note down that we are writing in row mode: it means we will have
+ # to take care of the end-of-copy marker too
+ self._row_mode = True
+
+ format_row_text(row, self.transformer, self._write_buffer)
+ if len(self._write_buffer) > self.BUFFER_SIZE:
+ buffer, self._write_buffer = self._write_buffer, bytearray()
+ return buffer
+ else:
+ return b""
+
+ def end(self) -> bytes:
+ buffer, self._write_buffer = self._write_buffer, bytearray()
+ return buffer
+
+ def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
+ if isinstance(data, bytes):
+ return data
+
+ elif isinstance(data, str):
+ return data.encode(self._encoding)
+
+ else:
+ raise TypeError(f"can't write {type(data).__name__}")
+
+
+class BinaryFormatter(Formatter):
+
+ format = pq.Format.BINARY
+
+ def __init__(self, transformer: Transformer):
+ super().__init__(transformer)
+ self._signature_sent = False
+
+ def parse_row(self, data: bytes) -> Optional[Tuple[Any, ...]]:
+ if not self._signature_sent:
+ if data[: len(_binary_signature)] != _binary_signature:
+ raise e.DataError(
+ "binary copy doesn't start with the expected signature"
+ )
+ self._signature_sent = True
+ data = data[len(_binary_signature) :]
+
+ elif data == _binary_trailer:
+ return None
+
+ return parse_row_binary(data, self.transformer)
+
+ def write(self, buffer: Union[str, bytes]) -> bytes:
+ data = self._ensure_bytes(buffer)
+ self._signature_sent = True
+ return data
+
+ def write_row(self, row: Sequence[Any]) -> bytes:
+ # Note down that we are writing in row mode: it means we will have
+ # to take care of the end-of-copy marker too
+ self._row_mode = True
+
+ if not self._signature_sent:
+ self._write_buffer += _binary_signature
+ self._signature_sent = True
+
+ format_row_binary(row, self.transformer, self._write_buffer)
+ if len(self._write_buffer) > self.BUFFER_SIZE:
+ buffer, self._write_buffer = self._write_buffer, bytearray()
+ return buffer
+ else:
+ return b""
+
+ def end(self) -> bytes:
+ # If we have sent no data we need to send the signature
+ # and the trailer
+ if not self._signature_sent:
+ self._write_buffer += _binary_signature
+ self._write_buffer += _binary_trailer
+
+ elif self._row_mode:
+ # if we have sent data already, we have sent the signature
+ # too (either with the first row, or we assume that in
+ # block mode the signature is included).
+ # Write the trailer only if we are sending rows (with the
+ # assumption that who is copying binary data is sending the
+ # whole format).
+ self._write_buffer += _binary_trailer
+
+ buffer, self._write_buffer = self._write_buffer, bytearray()
+ return buffer
+
+ def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
+ if isinstance(data, bytes):
+ return data
+
+ elif isinstance(data, str):
+ raise TypeError(
+ "cannot copy str data in binary mode: use bytes instead"
+ )
+
+ else:
+ raise TypeError(f"can't write {type(data).__name__}")
+
+
def _format_row_text(
row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
) -> bytearray: