]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Copy object interface simplified
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 15 Nov 2020 01:57:14 +0000 (01:57 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 15 Nov 2020 02:35:05 +0000 (02:35 +0000)
Dropped methods and code paths no more useful now that the user can only
interact with Copy as a context.

psycopg3/psycopg3/copy.py

index 4dfc95526d4ce0329574f49b2edf8f9252857e23..494d7719624f9bf9a753ffe4582cf1bf5b0f4c56 100644 (file)
@@ -24,48 +24,23 @@ class BaseCopy(Generic[ConnectionType]):
         self.connection = connection
         self.transformer = transformer
 
-        self.format = self.pgresult.binary_tuples
+        assert (
+            self.transformer.pgresult
+        ), "The Transformer doesn't have a PGresult set"
+        self._pgresult: "PGresult" = self.transformer.pgresult
+
+        self.format = self._pgresult.binary_tuples
+        self._encoding = self.connection.client_encoding
         self._first_row = True
         self._finished = False
-        self._encoding: str = ""
 
         if self.format == Format.TEXT:
-            self._format_row = self._format_row_text
-        else:
-            self._format_row = self._format_row_binary
-
-    @property
-    def finished(self) -> bool:
-        return self._finished
-
-    @property
-    def pgresult(self) -> "PGresult":
-        pgresult = self.transformer.pgresult
-        assert pgresult, "The Transformer doesn't have a PGresult set"
-        return pgresult
-
-    def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
-        if isinstance(data, bytes):
-            return data
-
-        elif isinstance(data, str):
-            if self._encoding:
-                return data.encode(self._encoding)
-
-            if (
-                self.pgresult is None
-                or self.pgresult.binary_tuples == Format.BINARY
-            ):
-                raise TypeError(
-                    "cannot copy str data in binary mode: use bytes instead"
-                )
-            self._encoding = self.connection.client_encoding
-            return data.encode(self._encoding)
-
+            self._format_copy_row = self._format_row_text
         else:
-            raise TypeError(f"can't write {type(data).__name__}")
+            self._format_copy_row = self._format_row_binary
 
-    def format_row(self, row: Sequence[Any]) -> bytes:
+    def _format_row(self, row: Sequence[Any]) -> bytes:
+        """Convert a Python sequence to the data to send for copy"""
         out: List[Optional[bytes]] = []
         for item in row:
             if item is not None:
@@ -73,9 +48,10 @@ class BaseCopy(Generic[ConnectionType]):
                 out.append(dumper.dump(item))
             else:
                 out.append(None)
-        return self._format_row(out)
+        return self._format_copy_row(out)
 
     def _format_row_text(self, row: Sequence[Optional[bytes]]) -> bytes:
+        """Convert a row of adapted data to the data to send for text copy"""
         return (
             b"\t".join(
                 _bsrepl_re.sub(_bsrepl_sub, item)
@@ -92,6 +68,7 @@ class BaseCopy(Generic[ConnectionType]):
         __int2_struct: struct.Struct = struct.Struct("!h"),
         __int4_struct: struct.Struct = struct.Struct("!i"),
     ) -> bytes:
+        """Convert a row of adapted data to the data to send for binary copy"""
         out = []
         if self._first_row:
             out.append(
@@ -112,6 +89,20 @@ class BaseCopy(Generic[ConnectionType]):
 
         return b"".join(out)
 
+    def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
+        if isinstance(data, bytes):
+            return data
+
+        elif isinstance(data, str):
+            if self._pgresult.binary_tuples == Format.BINARY:
+                raise TypeError(
+                    "cannot copy str data in binary mode: use bytes instead"
+                )
+            return data.encode(self._encoding)
+
+        else:
+            raise TypeError(f"can't write {type(data).__name__}")
+
 
 def _bsrepl_sub(
     m: Match[bytes],
@@ -156,10 +147,10 @@ class Copy(BaseCopy["Connection"]):
 
     def write_row(self, row: Sequence[Any]) -> None:
         """Write a record after a :sql:`COPY FROM` operation."""
-        data = self.format_row(row)
+        data = self._format_row(row)
         self.write(data)
 
-    def finish(self, error: str = "") -> None:
+    def _finish(self, error: str = "") -> None:
         """Terminate a :sql:`COPY FROM` operation."""
         conn = self.connection
         berr = error.encode(conn.client_encoding, "replace") if error else None
@@ -176,21 +167,21 @@ class Copy(BaseCopy["Connection"]):
         exc_tb: Optional[TracebackType],
     ) -> None:
         # no-op in COPY TO
-        if self.pgresult.status == ExecStatus.COPY_OUT:
+        if self._pgresult.status == ExecStatus.COPY_OUT:
             return
 
         if not exc_type:
             if self.format == Format.BINARY and not self._first_row:
                 # send EOF only if we copied binary rows (_first_row is False)
                 self.write(b"\xff\xff")
-            self.finish()
+            self._finish()
         else:
-            self.finish(
+            self._finish(
                 f"error from Python: {exc_type.__qualname__} - {exc_val}"
             )
 
     def __iter__(self) -> Iterator[bytes]:
-        while 1:
+        while True:
             data = self.read()
             if data is None:
                 break
@@ -216,10 +207,10 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
         await conn.wait(copy_to(conn.pgconn, self._ensure_bytes(buffer)))
 
     async def write_row(self, row: Sequence[Any]) -> None:
-        data = self.format_row(row)
+        data = self._format_row(row)
         await self.write(data)
 
-    async def finish(self, error: str = "") -> None:
+    async def _finish(self, error: str = "") -> None:
         conn = self.connection
         berr = error.encode(conn.client_encoding, "replace") if error else None
         await conn.wait(copy_end(conn.pgconn, berr))
@@ -235,21 +226,21 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
         exc_tb: Optional[TracebackType],
     ) -> None:
         # no-op in COPY TO
-        if self.pgresult.status == ExecStatus.COPY_OUT:
+        if self._pgresult.status == ExecStatus.COPY_OUT:
             return
 
         if not exc_type:
             if self.format == Format.BINARY and not self._first_row:
                 # send EOF only if we copied binary rows (_first_row is False)
                 await self.write(b"\xff\xff")
-            await self.finish()
+            await self._finish()
         else:
-            await self.finish(
+            await self._finish(
                 f"error from Python: {exc_type.__qualname__} - {exc_val}"
             )
 
     async def __aiter__(self) -> AsyncIterator[bytes]:
-        while 1:
+        while True:
             data = await self.read()
             if data is None:
                 break