]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Dropped async/sync code duplication using high level generators
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 30 Dec 2020 15:31:13 +0000 (16:31 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 8 Jan 2021 01:26:53 +0000 (02:26 +0100)
psycopg3/psycopg3/copy.py

index be5a7efba3776504ac255b132d926e38c0daf9ed..1b0bc17a99c358885f41e94210a86c9d43db2b6a 100644 (file)
@@ -12,7 +12,7 @@ from types import TracebackType
 
 from . import pq
 from .pq import Format, ExecStatus
-from .proto import ConnectionType
+from .proto import ConnectionType, PQGen
 from .generators import copy_from, copy_to, copy_end
 
 if TYPE_CHECKING:
@@ -47,6 +47,56 @@ class BaseCopy(Generic[ConnectionType]):
         info = pq.misc.connection_summary(self.connection.pgconn)
         return f"<{cls} {info} at 0x{id(self):x}>"
 
+    # High level copy protocol generators (state change of the Copy object)
+
+    def _read_gen(self) -> PQGen[bytes]:
+        if self._finished:
+            return b""
+
+        conn = self.connection
+        res = yield from copy_from(conn.pgconn)
+        if isinstance(res, bytes):
+            return res
+
+        # res is the final PGresult
+        self._finished = True
+        nrows = res.command_tuples
+        self.cursor._rowcount = nrows if nrows is not None else -1
+        return b""
+
+    def _write_gen(self, buffer: Union[str, bytes]) -> PQGen[None]:
+        conn = self.connection
+        yield from copy_to(conn.pgconn, self._ensure_bytes(buffer))
+
+    def _finish_gen(self, error: str = "") -> PQGen[None]:
+        conn = self.connection
+        berr = error.encode(conn.client_encoding, "replace") if error else None
+        res = yield from copy_end(conn.pgconn, berr)
+        nrows = res.command_tuples
+        self.cursor._rowcount = nrows if nrows is not None else -1
+        self._finished = True
+
+    def _exit_gen(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+    ) -> PQGen[None]:
+        # no-op in COPY TO
+        if self._pgresult.status == ExecStatus.COPY_OUT:
+            return
+
+        if not exc_type:
+            if self.format == Format.BINARY and not self._first_row:
+                # send EOF only if we copied binary rows (_first_row is False)
+                yield from self._write_gen(b"\xff\xff")
+            yield from self._finish_gen()
+        else:
+            yield from self._finish_gen(
+                f"error from Python: {exc_type.__qualname__} - {exc_val}"
+            )
+
+    # Support methods
+
     def _format_row(self, row: Sequence[Any]) -> bytes:
         """Convert a Python sequence to the data to send for copy"""
         out: List[Optional[bytes]] = []
@@ -111,6 +161,10 @@ class BaseCopy(Generic[ConnectionType]):
         else:
             raise TypeError(f"can't write {type(data).__name__}")
 
+    def _check_reuse(self) -> None:
+        if self._finished:
+            raise TypeError("copy blocks can be used only once")
+
 
 def _bsrepl_sub(
     m: Match[bytes],
@@ -140,19 +194,7 @@ class Copy(BaseCopy["Connection"]):
 
         Return an empty bytes string when the data is finished.
         """
-        if self._finished:
-            return b""
-
-        conn = self.connection
-        res = conn.wait(copy_from(conn.pgconn))
-        if isinstance(res, bytes):
-            return res
-
-        # res is the final PGresult
-        self._finished = True
-        nrows = res.command_tuples
-        self.cursor._rowcount = nrows if nrows is not None else -1
-        return b""
+        return self.connection.wait(self._read_gen())
 
     def write(self, buffer: Union[str, bytes]) -> None:
         """Write a block of data after a :sql:`COPY FROM` operation.
@@ -160,26 +202,19 @@ class Copy(BaseCopy["Connection"]):
         If the COPY is in binary format *buffer* must be `!bytes`. In text mode
         it can be either `!bytes` or `!str`.
         """
-        conn = self.connection
-        conn.wait(copy_to(conn.pgconn, self._ensure_bytes(buffer)))
+        self.connection.wait(self._write_gen(buffer))
 
     def write_row(self, row: Sequence[Any]) -> None:
         """Write a record after a :sql:`COPY FROM` operation."""
         data = self._format_row(row)
-        self.write(data)
+        self.connection.wait(self._write_gen(data))
 
     def _finish(self, error: str = "") -> None:
         """Terminate a :sql:`COPY FROM` operation."""
-        conn = self.connection
-        berr = error.encode(conn.client_encoding, "replace") if error else None
-        res = conn.wait(copy_end(conn.pgconn, berr))
-        nrows = res.command_tuples
-        self.cursor._rowcount = nrows if nrows is not None else -1
-        self._finished = True
+        self.connection.wait(self._finish_gen(error))
 
     def __enter__(self) -> "Copy":
-        if self._finished:
-            raise TypeError("copy blocks can be used only once")
+        self._check_reuse()
         return self
 
     def __exit__(
@@ -188,19 +223,7 @@ class Copy(BaseCopy["Connection"]):
         exc_val: Optional[BaseException],
         exc_tb: Optional[TracebackType],
     ) -> None:
-        # no-op in COPY TO
-        if self._pgresult.status == ExecStatus.COPY_OUT:
-            return
-
-        if not exc_type:
-            if self.format == Format.BINARY and not self._first_row:
-                # send EOF only if we copied binary rows (_first_row is False)
-                self.write(b"\xff\xff")
-            self._finish()
-        else:
-            self._finish(
-                f"error from Python: {exc_type.__qualname__} - {exc_val}"
-            )
+        self.connection.wait(self._exit_gen(exc_type, exc_val))
 
     def __iter__(self) -> Iterator[bytes]:
         while True:
@@ -216,39 +239,20 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
     __module__ = "psycopg3"
 
     async def read(self) -> bytes:
-        if self._finished:
-            return b""
-
-        conn = self.connection
-        res = await conn.wait(copy_from(conn.pgconn))
-        if isinstance(res, bytes):
-            return res
-
-        # res is the final PGresult
-        self._finished = True
-        nrows = res.command_tuples
-        self.cursor._rowcount = nrows if nrows is not None else -1
-        return b""
+        return await self.connection.wait(self._read_gen())
 
     async def write(self, buffer: Union[str, bytes]) -> None:
-        conn = self.connection
-        await conn.wait(copy_to(conn.pgconn, self._ensure_bytes(buffer)))
+        await self.connection.wait(self._write_gen(buffer))
 
     async def write_row(self, row: Sequence[Any]) -> None:
         data = self._format_row(row)
-        await self.write(data)
+        await self.connection.wait(self._write_gen(data))
 
     async def _finish(self, error: str = "") -> None:
-        conn = self.connection
-        berr = error.encode(conn.client_encoding, "replace") if error else None
-        res = await conn.wait(copy_end(conn.pgconn, berr))
-        nrows = res.command_tuples
-        self.cursor._rowcount = nrows if nrows is not None else -1
-        self._finished = True
+        await self.connection.wait(self._finish_gen(error))
 
     async def __aenter__(self) -> "AsyncCopy":
-        if self._finished:
-            raise TypeError("copy blocks can be used only once")
+        self._check_reuse()
         return self
 
     async def __aexit__(
@@ -257,19 +261,7 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
         exc_val: Optional[BaseException],
         exc_tb: Optional[TracebackType],
     ) -> None:
-        # no-op in COPY TO
-        if self._pgresult.status == ExecStatus.COPY_OUT:
-            return
-
-        if not exc_type:
-            if self.format == Format.BINARY and not self._first_row:
-                # send EOF only if we copied binary rows (_first_row is False)
-                await self.write(b"\xff\xff")
-            await self._finish()
-        else:
-            await self._finish(
-                f"error from Python: {exc_type.__qualname__} - {exc_val}"
-            )
+        await self.connection.wait(self._exit_gen(exc_type, exc_val))
 
     async def __aiter__(self) -> AsyncIterator[bytes]:
         while True: