From: Daniele Varrazzo Date: Fri, 15 Jan 2021 17:09:45 +0000 (+0100) Subject: Added copy Format helpers to handle text/binary differences X-Git-Tag: 3.0.dev0~154 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=35d672dc8b52ec671f238abac1a4146a3e1479e4;p=thirdparty%2Fpsycopg.git Added copy Format helpers to handle text/binary differences --- diff --git a/psycopg3/psycopg3/copy.py b/psycopg3/psycopg3/copy.py index ac0850f7e..d4b93869f 100644 --- a/psycopg3/psycopg3/copy.py +++ b/psycopg3/psycopg3/copy.py @@ -9,6 +9,7 @@ import queue import struct import asyncio import threading +from abc import ABC, abstractmethod from types import TracebackType from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple @@ -44,42 +45,39 @@ class BaseCopy(Generic[ConnectionType]): formatting the data in copy format and adding it to the queue. """ - # Size of data to accumulate before sending it down the network - BUFFER_SIZE = 32 * 1024 - # Max size of the write queue of buffers. More than that copy will block + # Each buffer around Formatter.BUFFER_SIZE size QUEUE_SIZE = 1024 + formatter: "Formatter" + def __init__(self, cursor: "BaseCursor[ConnectionType]"): self.cursor = cursor self.connection = cursor.connection - self.transformer = cursor._transformer self._pgconn = self.connection.pgconn - assert ( - self.transformer.pgresult - ), "The Transformer doesn't have a PGresult set" - self._pgresult: "PGresult" = self.transformer.pgresult - - self.format = pq.Format(self._pgresult.binary_tuples) - self._encoding = self.connection.client_encoding - self._signature_sent = False - self._row_mode = False # true if the user is using send_row() - self._write_buffer = bytearray() - self._finished = False + tx = cursor._transformer + assert tx.pgresult, "The Transformer doesn't have a PGresult set" + self._pgresult: "PGresult" = tx.pgresult - if self.format == TEXT: - self._format_row = format_row_text - self._parse_row = parse_row_text + if self._pgresult.binary_tuples == pq.Format.TEXT: + self.formatter = TextFormatter( + tx, encoding=self.connection.client_encoding + ) else: - self._format_row = format_row_binary - self._parse_row = parse_row_binary + self.formatter = BinaryFormatter(tx) + + self._finished = False def __repr__(self) -> str: cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" info = pq.misc.connection_summary(self._pgconn) return f"<{cls} {info} at 0x{id(self):x}>" + def _enter(self) -> None: + if self._finished: + raise TypeError("copy blocks can be used only once") + def set_types(self, types: Sequence[int]) -> None: """ Set the types expected out of a :sql:`COPY TO` operation. @@ -87,7 +85,9 @@ class BaseCopy(Generic[ConnectionType]): Without setting the types, the data from :sql:`COPY TO` will be returned as unparsed strings or bytes. """ - self.transformer.set_row_types(types, [self.format] * len(types)) + self.formatter.transformer.set_row_types( + types, [self.formatter.format] * len(types) + ) # High level copy protocol generators (state change of the Copy object) @@ -110,62 +110,14 @@ class BaseCopy(Generic[ConnectionType]): if not data: return None - if self.format == BINARY: - if not self._signature_sent: - if data[: len(_binary_signature)] != _binary_signature: - raise e.DataError( - "binary copy doesn't start with the expected signature" - ) - self._signature_sent = True - data = data[len(_binary_signature) :] - - elif data == _binary_trailer: - yield from self._read_gen() - self._finished = True - return None - - return self._parse_row(data, self.transformer) - - def _format_write(self, buffer: Union[str, bytes]) -> bytes: - data = self._ensure_bytes(buffer) - self._signature_sent = True - return data - - def _format_write_row(self, row: Sequence[Any]) -> bytes: - # Note down that we are writing in row mode: it means we will have - # to take care of the end-of-copy marker too - self._row_mode = True - - if self.format == BINARY and not self._signature_sent: - self._write_buffer += _binary_signature - self._signature_sent = True - - self._format_row(row, self.transformer, self._write_buffer) - if len(self._write_buffer) > self.BUFFER_SIZE: - buffer, self._write_buffer = self._write_buffer, bytearray() - return buffer - else: - return b"" + row = self.formatter.parse_row(data) + if row is None: + # Get the final result to finish the copy operation + yield from self._read_gen() + self._finished = True + return None - def _format_end(self) -> bytes: - if self.format == BINARY: - # If we have sent no data we need to send the signature - # and the trailer - if not self._signature_sent: - self._write_buffer += _binary_signature - self._write_buffer += _binary_trailer - elif self._row_mode: - - # if we have sent data already, we have sent the signature - # too (either with the first row, or we assume that in - # block mode the signature is included). - # Write the trailer only if we are sending rows (with the - # assumption that who is copying binary data is sending the - # whole format). - self._write_buffer += _binary_trailer - - buffer, self._write_buffer = self._write_buffer, bytearray() - return buffer + return row def _end_copy_gen(self, exc: Optional[BaseException]) -> PQGen[None]: bmsg: Optional[bytes] @@ -181,26 +133,6 @@ class BaseCopy(Generic[ConnectionType]): self.cursor._rowcount = nrows if nrows is not None else -1 self._finished = True - # Support methods - - def _ensure_bytes(self, data: Union[bytes, str]) -> bytes: - if isinstance(data, bytes): - return data - - elif isinstance(data, str): - if self._pgresult.binary_tuples == BINARY: - raise TypeError( - "cannot copy str data in binary mode: use bytes instead" - ) - return data.encode(self._encoding) - - else: - raise TypeError(f"can't write {type(data).__name__}") - - def _check_reuse(self) -> None: - if self._finished: - raise TypeError("copy blocks can be used only once") - class Copy(BaseCopy["Connection"]): """Manage a :sql:`COPY` operation.""" @@ -215,7 +147,7 @@ class Copy(BaseCopy["Connection"]): self._worker: Optional[threading.Thread] = None def __enter__(self) -> "Copy": - self._check_reuse() + self._enter() return self def __exit__( @@ -275,12 +207,12 @@ class Copy(BaseCopy["Connection"]): If the :sql:`COPY` is in binary format *buffer* must be `!bytes`. In text mode it can be either `!bytes` or `!str`. """ - data = self._format_write(buffer) + data = self.formatter.write(buffer) self._write(data) def write_row(self, row: Sequence[Any]) -> None: """Write a record to a table after a :sql:`COPY FROM` operation.""" - data = self._format_write_row(row) + data = self.formatter.write_row(row) self._write(data) def finish(self, exc: Optional[BaseException]) -> None: @@ -325,7 +257,7 @@ class Copy(BaseCopy["Connection"]): self._queue.put(data) def _write_end(self) -> None: - data = self._format_end() + data = self.formatter.end() self._write(data) self._queue.put(None) @@ -347,7 +279,7 @@ class AsyncCopy(BaseCopy["AsyncConnection"]): self._worker: Optional[asyncio.Future[None]] = None async def __aenter__(self) -> "AsyncCopy": - self._check_reuse() + self._enter() return self async def __aexit__( @@ -379,11 +311,11 @@ class AsyncCopy(BaseCopy["AsyncConnection"]): return await self.connection.wait(self._read_row_gen()) async def write(self, buffer: Union[str, bytes]) -> None: - data = self._format_write(buffer) + data = self.formatter.write(buffer) await self._write(data) async def write_row(self, row: Sequence[Any]) -> None: - data = self._format_write_row(row) + data = self.formatter.write_row(row) await self._write(data) async def finish(self, exc: Optional[BaseException]) -> None: @@ -420,7 +352,7 @@ class AsyncCopy(BaseCopy["AsyncConnection"]): await self._queue.put(data) async def _write_end(self) -> None: - data = self._format_end() + data = self.formatter.end() await self._write(data) await self._queue.put(None) @@ -429,6 +361,159 @@ class AsyncCopy(BaseCopy["AsyncConnection"]): self._worker = None # break reference loops if any +class Formatter(ABC): + """ + A class which understand a copy format (text, binary). + """ + + format: pq.Format + + # Size of data to accumulate before sending it down the network + BUFFER_SIZE = 32 * 1024 + + def __init__(self, transformer: Transformer): + self.transformer = transformer + self._write_buffer = bytearray() + self._row_mode = False # true if the user is using send_row() + + @abstractmethod + def parse_row(self, data: bytes) -> Optional[Tuple[Any, ...]]: + ... + + @abstractmethod + def write(self, buffer: Union[str, bytes]) -> bytes: + ... + + @abstractmethod + def write_row(self, row: Sequence[Any]) -> bytes: + ... + + @abstractmethod + def end(self) -> bytes: + ... + + +class TextFormatter(Formatter): + + format = pq.Format.TEXT + + def __init__(self, transformer: Transformer, encoding: str = "utf-8"): + super().__init__(transformer) + self._encoding = encoding + + def parse_row(self, data: bytes) -> Optional[Tuple[Any, ...]]: + if data: + return parse_row_text(data, self.transformer) + else: + return None + + def write(self, buffer: Union[str, bytes]) -> bytes: + data = self._ensure_bytes(buffer) + self._signature_sent = True + return data + + def write_row(self, row: Sequence[Any]) -> bytes: + # Note down that we are writing in row mode: it means we will have + # to take care of the end-of-copy marker too + self._row_mode = True + + format_row_text(row, self.transformer, self._write_buffer) + if len(self._write_buffer) > self.BUFFER_SIZE: + buffer, self._write_buffer = self._write_buffer, bytearray() + return buffer + else: + return b"" + + def end(self) -> bytes: + buffer, self._write_buffer = self._write_buffer, bytearray() + return buffer + + def _ensure_bytes(self, data: Union[bytes, str]) -> bytes: + if isinstance(data, bytes): + return data + + elif isinstance(data, str): + return data.encode(self._encoding) + + else: + raise TypeError(f"can't write {type(data).__name__}") + + +class BinaryFormatter(Formatter): + + format = pq.Format.BINARY + + def __init__(self, transformer: Transformer): + super().__init__(transformer) + self._signature_sent = False + + def parse_row(self, data: bytes) -> Optional[Tuple[Any, ...]]: + if not self._signature_sent: + if data[: len(_binary_signature)] != _binary_signature: + raise e.DataError( + "binary copy doesn't start with the expected signature" + ) + self._signature_sent = True + data = data[len(_binary_signature) :] + + elif data == _binary_trailer: + return None + + return parse_row_binary(data, self.transformer) + + def write(self, buffer: Union[str, bytes]) -> bytes: + data = self._ensure_bytes(buffer) + self._signature_sent = True + return data + + def write_row(self, row: Sequence[Any]) -> bytes: + # Note down that we are writing in row mode: it means we will have + # to take care of the end-of-copy marker too + self._row_mode = True + + if not self._signature_sent: + self._write_buffer += _binary_signature + self._signature_sent = True + + format_row_binary(row, self.transformer, self._write_buffer) + if len(self._write_buffer) > self.BUFFER_SIZE: + buffer, self._write_buffer = self._write_buffer, bytearray() + return buffer + else: + return b"" + + def end(self) -> bytes: + # If we have sent no data we need to send the signature + # and the trailer + if not self._signature_sent: + self._write_buffer += _binary_signature + self._write_buffer += _binary_trailer + + elif self._row_mode: + # if we have sent data already, we have sent the signature + # too (either with the first row, or we assume that in + # block mode the signature is included). + # Write the trailer only if we are sending rows (with the + # assumption that who is copying binary data is sending the + # whole format). + self._write_buffer += _binary_trailer + + buffer, self._write_buffer = self._write_buffer, bytearray() + return buffer + + def _ensure_bytes(self, data: Union[bytes, str]) -> bytes: + if isinstance(data, bytes): + return data + + elif isinstance(data, str): + raise TypeError( + "cannot copy str data in binary mode: use bytes instead" + ) + + else: + raise TypeError(f"can't write {type(data).__name__}") + + def _format_row_text( row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None ) -> bytearray: