git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added copy Format helpers to handle text/binary differences
author: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 15 Jan 2021 17:09:45 +0000 (18:09 +0100)
committer: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 16 Jan 2021 01:06:56 +0000 (02:06 +0100)
psycopg3/psycopg3/copy.py

index ac0850f7e1ca945df140b12faf9ce155a6b1e1e9..d4b93869f5d02096de6386ef5cb2961af6a21096 100644 (file)
@@ -9,6 +9,7 @@ import queue
 import struct
 import asyncio
 import threading
+from abc import ABC, abstractmethod
 from types import TracebackType
 from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union
 from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple
@@ -44,42 +45,39 @@ class BaseCopy(Generic[ConnectionType]):
     formatting the data in copy format and adding it to the queue.
     """
 
-    # Size of data to accumulate before sending it down the network
-    BUFFER_SIZE = 32 * 1024
-
     # Max size of the write queue of buffers. More than that copy will block
+    # Each buffer is around Formatter.BUFFER_SIZE in size
     QUEUE_SIZE = 1024
 
+    formatter: "Formatter"
+
     def __init__(self, cursor: "BaseCursor[ConnectionType]"):
         self.cursor = cursor
         self.connection = cursor.connection
-        self.transformer = cursor._transformer
         self._pgconn = self.connection.pgconn
 
-        assert (
-            self.transformer.pgresult
-        ), "The Transformer doesn't have a PGresult set"
-        self._pgresult: "PGresult" = self.transformer.pgresult
-
-        self.format = pq.Format(self._pgresult.binary_tuples)
-        self._encoding = self.connection.client_encoding
-        self._signature_sent = False
-        self._row_mode = False  # true if the user is using send_row()
-        self._write_buffer = bytearray()
-        self._finished = False
+        tx = cursor._transformer
+        assert tx.pgresult, "The Transformer doesn't have a PGresult set"
+        self._pgresult: "PGresult" = tx.pgresult
 
-        if self.format == TEXT:
-            self._format_row = format_row_text
-            self._parse_row = parse_row_text
+        if self._pgresult.binary_tuples == pq.Format.TEXT:
+            self.formatter = TextFormatter(
+                tx, encoding=self.connection.client_encoding
+            )
         else:
-            self._format_row = format_row_binary
-            self._parse_row = parse_row_binary
+            self.formatter = BinaryFormatter(tx)
+
+        self._finished = False
 
     def __repr__(self) -> str:
         cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
         info = pq.misc.connection_summary(self._pgconn)
         return f"<{cls} {info} at 0x{id(self):x}>"
 
+    def _enter(self) -> None:
+        if self._finished:
+            raise TypeError("copy blocks can be used only once")
+
     def set_types(self, types: Sequence[int]) -> None:
         """
         Set the types expected out of a :sql:`COPY TO` operation.
@@ -87,7 +85,9 @@ class BaseCopy(Generic[ConnectionType]):
         Without setting the types, the data from :sql:`COPY TO` will be
         returned as unparsed strings or bytes.
         """
-        self.transformer.set_row_types(types, [self.format] * len(types))
+        self.formatter.transformer.set_row_types(
+            types, [self.formatter.format] * len(types)
+        )
 
     # High level copy protocol generators (state change of the Copy object)
 
@@ -110,62 +110,14 @@ class BaseCopy(Generic[ConnectionType]):
         if not data:
             return None
 
-        if self.format == BINARY:
-            if not self._signature_sent:
-                if data[: len(_binary_signature)] != _binary_signature:
-                    raise e.DataError(
-                        "binary copy doesn't start with the expected signature"
-                    )
-                self._signature_sent = True
-                data = data[len(_binary_signature) :]
-
-            elif data == _binary_trailer:
-                yield from self._read_gen()
-                self._finished = True
-                return None
-
-        return self._parse_row(data, self.transformer)
-
-    def _format_write(self, buffer: Union[str, bytes]) -> bytes:
-        data = self._ensure_bytes(buffer)
-        self._signature_sent = True
-        return data
-
-    def _format_write_row(self, row: Sequence[Any]) -> bytes:
-        # Note down that we are writing in row mode: it means we will have
-        # to take care of the end-of-copy marker too
-        self._row_mode = True
-
-        if self.format == BINARY and not self._signature_sent:
-            self._write_buffer += _binary_signature
-            self._signature_sent = True
-
-        self._format_row(row, self.transformer, self._write_buffer)
-        if len(self._write_buffer) > self.BUFFER_SIZE:
-            buffer, self._write_buffer = self._write_buffer, bytearray()
-            return buffer
-        else:
-            return b""
+        row = self.formatter.parse_row(data)
+        if row is None:
+            # Get the final result to finish the copy operation
+            yield from self._read_gen()
+            self._finished = True
+            return None
 
-    def _format_end(self) -> bytes:
-        if self.format == BINARY:
-            # If we have sent no data we need to send the signature
-            # and the trailer
-            if not self._signature_sent:
-                self._write_buffer += _binary_signature
-                self._write_buffer += _binary_trailer
-            elif self._row_mode:
-
-                # if we have sent data already, we have sent the signature
-                # too (either with the first row, or we assume that in
-                # block mode the signature is included).
-                # Write the trailer only if we are sending rows (with the
-                # assumption that who is copying binary data is sending the
-                # whole format).
-                self._write_buffer += _binary_trailer
-
-        buffer, self._write_buffer = self._write_buffer, bytearray()
-        return buffer
+        return row
 
     def _end_copy_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
         bmsg: Optional[bytes]
@@ -181,26 +133,6 @@ class BaseCopy(Generic[ConnectionType]):
         self.cursor._rowcount = nrows if nrows is not None else -1
         self._finished = True
 
-    # Support methods
-
-    def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
-        if isinstance(data, bytes):
-            return data
-
-        elif isinstance(data, str):
-            if self._pgresult.binary_tuples == BINARY:
-                raise TypeError(
-                    "cannot copy str data in binary mode: use bytes instead"
-                )
-            return data.encode(self._encoding)
-
-        else:
-            raise TypeError(f"can't write {type(data).__name__}")
-
-    def _check_reuse(self) -> None:
-        if self._finished:
-            raise TypeError("copy blocks can be used only once")
-
 
 class Copy(BaseCopy["Connection"]):
     """Manage a :sql:`COPY` operation."""
@@ -215,7 +147,7 @@ class Copy(BaseCopy["Connection"]):
         self._worker: Optional[threading.Thread] = None
 
     def __enter__(self) -> "Copy":
-        self._check_reuse()
+        self._enter()
         return self
 
     def __exit__(
@@ -275,12 +207,12 @@ class Copy(BaseCopy["Connection"]):
         If the :sql:`COPY` is in binary format *buffer* must be `!bytes`. In
         text mode it can be either `!bytes` or `!str`.
         """
-        data = self._format_write(buffer)
+        data = self.formatter.write(buffer)
         self._write(data)
 
     def write_row(self, row: Sequence[Any]) -> None:
         """Write a record to a table after a :sql:`COPY FROM` operation."""
-        data = self._format_write_row(row)
+        data = self.formatter.write_row(row)
         self._write(data)
 
     def finish(self, exc: Optional[BaseException]) -> None:
@@ -325,7 +257,7 @@ class Copy(BaseCopy["Connection"]):
         self._queue.put(data)
 
     def _write_end(self) -> None:
-        data = self._format_end()
+        data = self.formatter.end()
         self._write(data)
         self._queue.put(None)
 
@@ -347,7 +279,7 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
         self._worker: Optional[asyncio.Future[None]] = None
 
     async def __aenter__(self) -> "AsyncCopy":
-        self._check_reuse()
+        self._enter()
         return self
 
     async def __aexit__(
@@ -379,11 +311,11 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
         return await self.connection.wait(self._read_row_gen())
 
     async def write(self, buffer: Union[str, bytes]) -> None:
-        data = self._format_write(buffer)
+        data = self.formatter.write(buffer)
         await self._write(data)
 
     async def write_row(self, row: Sequence[Any]) -> None:
-        data = self._format_write_row(row)
+        data = self.formatter.write_row(row)
         await self._write(data)
 
     async def finish(self, exc: Optional[BaseException]) -> None:
@@ -420,7 +352,7 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
         await self._queue.put(data)
 
     async def _write_end(self) -> None:
-        data = self._format_end()
+        data = self.formatter.end()
         await self._write(data)
         await self._queue.put(None)
 
@@ -429,6 +361,159 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
             self._worker = None  # break reference loops if any
 
 
+class Formatter(ABC):
+    """
+    A class which understands a copy format (text, binary).
+    """
+
+    format: pq.Format
+
+    # Size of data to accumulate before sending it down the network
+    BUFFER_SIZE = 32 * 1024
+
+    def __init__(self, transformer: Transformer):
+        self.transformer = transformer
+        self._write_buffer = bytearray()
+        self._row_mode = False  # true if the user is using write_row()
+
+    @abstractmethod
+    def parse_row(self, data: bytes) -> Optional[Tuple[Any, ...]]:
+        ...
+
+    @abstractmethod
+    def write(self, buffer: Union[str, bytes]) -> bytes:
+        ...
+
+    @abstractmethod
+    def write_row(self, row: Sequence[Any]) -> bytes:
+        ...
+
+    @abstractmethod
+    def end(self) -> bytes:
+        ...
+
+
+class TextFormatter(Formatter):
+
+    format = pq.Format.TEXT
+
+    def __init__(self, transformer: Transformer, encoding: str = "utf-8"):
+        super().__init__(transformer)
+        self._encoding = encoding
+
+    def parse_row(self, data: bytes) -> Optional[Tuple[Any, ...]]:
+        if data:
+            return parse_row_text(data, self.transformer)
+        else:
+            return None
+
+    def write(self, buffer: Union[str, bytes]) -> bytes:
+        data = self._ensure_bytes(buffer)
+        self._signature_sent = True
+        return data
+
+    def write_row(self, row: Sequence[Any]) -> bytes:
+        # Note down that we are writing in row mode: it means we will have
+        # to take care of the end-of-copy marker too
+        self._row_mode = True
+
+        format_row_text(row, self.transformer, self._write_buffer)
+        if len(self._write_buffer) > self.BUFFER_SIZE:
+            buffer, self._write_buffer = self._write_buffer, bytearray()
+            return buffer
+        else:
+            return b""
+
+    def end(self) -> bytes:
+        buffer, self._write_buffer = self._write_buffer, bytearray()
+        return buffer
+
+    def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
+        if isinstance(data, bytes):
+            return data
+
+        elif isinstance(data, str):
+            return data.encode(self._encoding)
+
+        else:
+            raise TypeError(f"can't write {type(data).__name__}")
+
+
+class BinaryFormatter(Formatter):
+
+    format = pq.Format.BINARY
+
+    def __init__(self, transformer: Transformer):
+        super().__init__(transformer)
+        self._signature_sent = False
+
+    def parse_row(self, data: bytes) -> Optional[Tuple[Any, ...]]:
+        if not self._signature_sent:
+            if data[: len(_binary_signature)] != _binary_signature:
+                raise e.DataError(
+                    "binary copy doesn't start with the expected signature"
+                )
+            self._signature_sent = True
+            data = data[len(_binary_signature) :]
+
+        elif data == _binary_trailer:
+            return None
+
+        return parse_row_binary(data, self.transformer)
+
+    def write(self, buffer: Union[str, bytes]) -> bytes:
+        data = self._ensure_bytes(buffer)
+        self._signature_sent = True
+        return data
+
+    def write_row(self, row: Sequence[Any]) -> bytes:
+        # Note down that we are writing in row mode: it means we will have
+        # to take care of the end-of-copy marker too
+        self._row_mode = True
+
+        if not self._signature_sent:
+            self._write_buffer += _binary_signature
+            self._signature_sent = True
+
+        format_row_binary(row, self.transformer, self._write_buffer)
+        if len(self._write_buffer) > self.BUFFER_SIZE:
+            buffer, self._write_buffer = self._write_buffer, bytearray()
+            return buffer
+        else:
+            return b""
+
+    def end(self) -> bytes:
+        # If we have sent no data we need to send the signature
+        # and the trailer
+        if not self._signature_sent:
+            self._write_buffer += _binary_signature
+            self._write_buffer += _binary_trailer
+
+        elif self._row_mode:
+            # if we have sent data already, we have sent the signature
+            # too (either with the first row, or we assume that in
+            # block mode the signature is included).
+            # Write the trailer only if we are sending rows (with the
+            # assumption that whoever is copying binary data is sending the
+            # whole format).
+            self._write_buffer += _binary_trailer
+
+        buffer, self._write_buffer = self._write_buffer, bytearray()
+        return buffer
+
+    def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
+        if isinstance(data, bytes):
+            return data
+
+        elif isinstance(data, str):
+            raise TypeError(
+                "cannot copy str data in binary mode: use bytes instead"
+            )
+
+        else:
+            raise TypeError(f"can't write {type(data).__name__}")
+
+
 def _format_row_text(
     row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
 ) -> bytearray: