git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Copy objects refactoring to isolate the row format function
author: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 31 Dec 2020 19:25:57 +0000 (20:25 +0100)
committer: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 8 Jan 2021 01:26:53 +0000 (02:26 +0100)
psycopg3/psycopg3/copy.py

index fe39b4656be1000646d51c2d5cf5af6bfeaa0372..99959447148c6d588411e6b65ba691305112fa60 100644 (file)
@@ -26,6 +26,7 @@ class BaseCopy(Generic[ConnectionType]):
         self.cursor = cursor
         self.connection = cursor.connection
         self.transformer = cursor._transformer
+        self._pgconn = self.connection.pgconn
 
         assert (
             self.transformer.pgresult
@@ -39,13 +40,13 @@ class BaseCopy(Generic[ConnectionType]):
         self._finished = False
 
         if self.format == Format.TEXT:
-            self._format_copy_row = self._format_row_text
+            self._format_copy_row = format_row_text
         else:
-            self._format_copy_row = self._format_row_binary
+            self._format_copy_row = format_row_binary
 
     def __repr__(self) -> str:
         cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
-        info = pq.misc.connection_summary(self.connection.pgconn)
+        info = pq.misc.connection_summary(self._pgconn)
         return f"<{cls} {info} at 0x{id(self):x}>"
 
     # High level copy protocol generators (state change of the Copy object)
@@ -54,8 +55,7 @@ class BaseCopy(Generic[ConnectionType]):
         if self._finished:
             return b""
 
-        conn = self.connection
-        res = yield from copy_from(conn.pgconn)
+        res = yield from copy_from(self._pgconn)
         if isinstance(res, bytes):
             return res
 
@@ -66,18 +66,30 @@ class BaseCopy(Generic[ConnectionType]):
         return b""
 
     def _write_gen(self, buffer: Union[str, bytes]) -> PQGen[None]:
-        conn = self.connection
         # if write() was called, assume the header was sent together with the
-        # first block of data (either because we added it to the first row
-        # or, if the user is copying blocks, assume the blocks contain
-        # the header).
+        # first block of data.
         self._signature_sent = True
-        yield from copy_to(conn.pgconn, self._ensure_bytes(buffer))
+        yield from copy_to(self._pgconn, self._ensure_bytes(buffer))
+
+    def _write_row_gen(self, row: Sequence[Any]) -> PQGen[None]:
+        # Note down that we are writing in row mode: it means we will have
+        # to take care of the end-of-copy marker too
+        self._row_mode = True
+
+        data = self._format_row(row)
+        if self.format == Format.BINARY and not self._signature_sent:
+            yield from copy_to(self._pgconn, _binary_signature)
+            self._signature_sent = True
+
+        yield from copy_to(self._pgconn, data)
 
     def _finish_gen(self, error: str = "") -> PQGen[None]:
-        conn = self.connection
-        berr = error.encode(conn.client_encoding, "replace") if error else None
-        res = yield from copy_end(conn.pgconn, berr)
+        berr = (
+            error.encode(self.connection.client_encoding, "replace")
+            if error
+            else None
+        )
+        res = yield from copy_end(self._pgconn, berr)
         nrows = res.command_tuples
         self.cursor._rowcount = nrows if nrows is not None else -1
         self._finished = True
@@ -102,8 +114,8 @@ class BaseCopy(Generic[ConnectionType]):
             # If we have sent no data we need to send the signature
             # and the trailer
             if not self._signature_sent:
-                yield from self._write_gen(self._binary_signature)
-                yield from self._write_gen(self._binary_trailer)
+                yield from copy_to(self._pgconn, _binary_signature)
+                yield from copy_to(self._pgconn, _binary_trailer)
             elif self._row_mode:
                 # if we have sent data already, we have sent the signature too
                 # (either with the first row, or we assume that in block mode
@@ -111,7 +123,7 @@ class BaseCopy(Generic[ConnectionType]):
                 # Write the trailer only if we are sending rows (with the
                 # assumption that who is copying binary data is sending the
                 # whole format).
-                yield from self._write_gen(self._binary_trailer)
+                yield from copy_to(self._pgconn, _binary_trailer)
 
         yield from self._finish_gen()
 
@@ -129,53 +141,6 @@ class BaseCopy(Generic[ConnectionType]):
 
         return self._format_copy_row(out)
 
-    def _format_row_text(self, row: Sequence[Optional[bytes]]) -> bytes:
-        """Convert a row of adapted data to the data to send for text copy"""
-        return (
-            b"\t".join(
-                _bsrepl_re.sub(_bsrepl_sub, item)
-                if item is not None
-                else br"\N"
-                for item in row
-            )
-            + b"\n"
-        )
-
-    def _format_row_binary(
-        self,
-        row: Sequence[Optional[bytes]],
-        __int2_struct: struct.Struct = struct.Struct("!h"),
-        __int4_struct: struct.Struct = struct.Struct("!i"),
-    ) -> bytes:
-        """Convert a row of adapted data to the data to send for binary copy"""
-        out = []
-        if not self._signature_sent:
-            out.append(self._binary_signature)
-            self._signature_sent = True
-
-            # Note down that we are writing in row mode: it means we will have
-            # to take care of the end-of-copy marker too
-            self._row_mode = True
-
-        out.append(__int2_struct.pack(len(row)))
-        for item in row:
-            if item is not None:
-                out.append(__int4_struct.pack(len(item)))
-                out.append(item)
-            else:
-                out.append(self._binary_null)
-
-        return b"".join(out)
-
-    _binary_signature = (
-        # Signature, flags, extra length
-        b"PGCOPY\n\xff\r\n\0"
-        b"\x00\x00\x00\x00"
-        b"\x00\x00\x00\x00"
-    )
-    _binary_trailer = b"\xff\xff"
-    _binary_null = b"\xff\xff\xff\xff"
-
     def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
         if isinstance(data, bytes):
             return data
@@ -195,24 +160,6 @@ class BaseCopy(Generic[ConnectionType]):
             raise TypeError("copy blocks can be used only once")
 
 
-def _bsrepl_sub(
-    m: Match[bytes],
-    __map: Dict[bytes, bytes] = {
-        b"\b": b"\\b",
-        b"\t": b"\\t",
-        b"\n": b"\\n",
-        b"\v": b"\\v",
-        b"\f": b"\\f",
-        b"\r": b"\\r",
-        b"\\": b"\\\\",
-    },
-) -> bytes:
-    return __map[m.group(0)]
-
-
-_bsrepl_re = re.compile(b"[\b\t\n\v\f\r\\\\]")
-
-
 class Copy(BaseCopy["Connection"]):
     """Manage a :sql:`COPY` operation."""
 
@@ -235,8 +182,7 @@ class Copy(BaseCopy["Connection"]):
 
     def write_row(self, row: Sequence[Any]) -> None:
         """Write a record after a :sql:`COPY FROM` operation."""
-        data = self._format_row(row)
-        self.connection.wait(self._write_gen(data))
+        self.connection.wait(self._write_row_gen(row))
 
     def _finish(self, error: str = "") -> None:
         """Terminate a :sql:`COPY FROM` operation."""
@@ -274,8 +220,7 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
         await self.connection.wait(self._write_gen(buffer))
 
     async def write_row(self, row: Sequence[Any]) -> None:
-        data = self._format_row(row)
-        await self.connection.wait(self._write_gen(data))
+        await self.connection.wait(self._write_row_gen(row))
 
     async def _finish(self, error: str = "") -> None:
         await self.connection.wait(self._finish_gen(error))
@@ -298,3 +243,60 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
             if not data:
                 break
             yield data
+
+
+def format_row_text(row: Sequence[Optional[bytes]]) -> bytes:
+    """Convert a row of adapted data to the data to send for text copy"""
+    return (
+        b"\t".join(
+            _bsrepl_re.sub(_bsrepl_sub, item) if item is not None else br"\N"
+            for item in row
+        )
+        + b"\n"
+    )
+
+
+def format_row_binary(
+    row: Sequence[Optional[bytes]],
+    __int2_struct: struct.Struct = struct.Struct("!h"),
+    __int4_struct: struct.Struct = struct.Struct("!i"),
+) -> bytes:
+    """Convert a row of adapted data to the data to send for binary copy"""
+    out = []
+    out.append(__int2_struct.pack(len(row)))
+    for item in row:
+        if item is not None:
+            out.append(__int4_struct.pack(len(item)))
+            out.append(item)
+        else:
+            out.append(_binary_null)
+
+    return b"".join(out)
+
+
+_binary_signature = (
+    # Signature, flags, extra length
+    b"PGCOPY\n\xff\r\n\0"
+    b"\x00\x00\x00\x00"
+    b"\x00\x00\x00\x00"
+)
+_binary_trailer = b"\xff\xff"
+_binary_null = b"\xff\xff\xff\xff"
+
+
+def _bsrepl_sub(
+    m: Match[bytes],
+    __map: Dict[bytes, bytes] = {
+        b"\b": b"\\b",
+        b"\t": b"\\t",
+        b"\n": b"\\n",
+        b"\v": b"\\v",
+        b"\f": b"\\f",
+        b"\r": b"\\r",
+        b"\\": b"\\\\",
+    },
+) -> bytes:
+    return __map[m.group(0)]
+
+
+_bsrepl_re = re.compile(b"[\b\t\n\v\f\r\\\\]")