From: Daniele Varrazzo Date: Wed, 30 Dec 2020 15:31:13 +0000 (+0100) Subject: Dropped async/sync code duplication using high level generators X-Git-Tag: 3.0.dev0~223 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ef4bc626e9be098bf9c6071089e0600424f96143;p=thirdparty%2Fpsycopg.git Dropped async/sync code duplication using high level generators --- diff --git a/psycopg3/psycopg3/copy.py b/psycopg3/psycopg3/copy.py index be5a7efba..1b0bc17a9 100644 --- a/psycopg3/psycopg3/copy.py +++ b/psycopg3/psycopg3/copy.py @@ -12,7 +12,7 @@ from types import TracebackType from . import pq from .pq import Format, ExecStatus -from .proto import ConnectionType +from .proto import ConnectionType, PQGen from .generators import copy_from, copy_to, copy_end if TYPE_CHECKING: @@ -47,6 +47,56 @@ class BaseCopy(Generic[ConnectionType]): info = pq.misc.connection_summary(self.connection.pgconn) return f"<{cls} {info} at 0x{id(self):x}>" + # High level copy protocol generators (state change of the Copy object) + + def _read_gen(self) -> PQGen[bytes]: + if self._finished: + return b"" + + conn = self.connection + res = yield from copy_from(conn.pgconn) + if isinstance(res, bytes): + return res + + # res is the final PGresult + self._finished = True + nrows = res.command_tuples + self.cursor._rowcount = nrows if nrows is not None else -1 + return b"" + + def _write_gen(self, buffer: Union[str, bytes]) -> PQGen[None]: + conn = self.connection + yield from copy_to(conn.pgconn, self._ensure_bytes(buffer)) + + def _finish_gen(self, error: str = "") -> PQGen[None]: + conn = self.connection + berr = error.encode(conn.client_encoding, "replace") if error else None + res = yield from copy_end(conn.pgconn, berr) + nrows = res.command_tuples + self.cursor._rowcount = nrows if nrows is not None else -1 + self._finished = True + + def _exit_gen( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + ) -> PQGen[None]: + # no-op in COPY TO + if self._pgresult.status == ExecStatus.COPY_OUT: + return + + if not exc_type: + if self.format == Format.BINARY and not self._first_row: + # send EOF only if we copied binary rows (_first_row is False) + yield from self._write_gen(b"\xff\xff") + yield from self._finish_gen() + else: + yield from self._finish_gen( + f"error from Python: {exc_type.__qualname__} - {exc_val}" + ) + + # Support methods + def _format_row(self, row: Sequence[Any]) -> bytes: """Convert a Python sequence to the data to send for copy""" out: List[Optional[bytes]] = [] @@ -111,6 +161,10 @@ class BaseCopy(Generic[ConnectionType]): else: raise TypeError(f"can't write {type(data).__name__}") + def _check_reuse(self) -> None: + if self._finished: + raise TypeError("copy blocks can be used only once") + def _bsrepl_sub( m: Match[bytes], @@ -140,19 +194,7 @@ class Copy(BaseCopy["Connection"]): Return an empty bytes string when the data is finished. """ - if self._finished: - return b"" - - conn = self.connection - res = conn.wait(copy_from(conn.pgconn)) - if isinstance(res, bytes): - return res - - # res is the final PGresult - self._finished = True - nrows = res.command_tuples - self.cursor._rowcount = nrows if nrows is not None else -1 - return b"" + return self.connection.wait(self._read_gen()) def write(self, buffer: Union[str, bytes]) -> None: """Write a block of data after a :sql:`COPY FROM` operation. @@ -160,26 +202,19 @@ class Copy(BaseCopy["Connection"]): If the COPY is in binary format *buffer* must be `!bytes`. In text mode it can be either `!bytes` or `!str`. """ - conn = self.connection - conn.wait(copy_to(conn.pgconn, self._ensure_bytes(buffer))) + self.connection.wait(self._write_gen(buffer)) def write_row(self, row: Sequence[Any]) -> None: """Write a record after a :sql:`COPY FROM` operation.""" data = self._format_row(row) - self.write(data) + self.connection.wait(self._write_gen(data)) def _finish(self, error: str = "") -> None: """Terminate a :sql:`COPY FROM` operation.""" - conn = self.connection - berr = error.encode(conn.client_encoding, "replace") if error else None - res = conn.wait(copy_end(conn.pgconn, berr)) - nrows = res.command_tuples - self.cursor._rowcount = nrows if nrows is not None else -1 - self._finished = True + self.connection.wait(self._finish_gen(error)) def __enter__(self) -> "Copy": - if self._finished: - raise TypeError("copy blocks can be used only once") + self._check_reuse() return self def __exit__( @@ -188,19 +223,7 @@ class Copy(BaseCopy["Connection"]): exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: - # no-op in COPY TO - if self._pgresult.status == ExecStatus.COPY_OUT: - return - - if not exc_type: - if self.format == Format.BINARY and not self._first_row: - # send EOF only if we copied binary rows (_first_row is False) - self.write(b"\xff\xff") - self._finish() - else: - self._finish( - f"error from Python: {exc_type.__qualname__} - {exc_val}" - ) + self.connection.wait(self._exit_gen(exc_type, exc_val)) def __iter__(self) -> Iterator[bytes]: while True: @@ -216,39 +239,20 @@ class AsyncCopy(BaseCopy["AsyncConnection"]): __module__ = "psycopg3" async def read(self) -> bytes: - if self._finished: - return b"" - - conn = self.connection - res = await conn.wait(copy_from(conn.pgconn)) - if isinstance(res, bytes): - return res - - # res is the final PGresult - self._finished = True - nrows = res.command_tuples - self.cursor._rowcount = nrows if nrows is not None else -1 - return b"" + return await self.connection.wait(self._read_gen()) async def write(self, buffer: Union[str, bytes]) -> None: - conn = self.connection - await conn.wait(copy_to(conn.pgconn, self._ensure_bytes(buffer))) + await self.connection.wait(self._write_gen(buffer)) async def write_row(self, row: Sequence[Any]) -> None: data = self._format_row(row) - await self.write(data) + await self.connection.wait(self._write_gen(data)) async def _finish(self, error: str = "") -> None: - conn = self.connection - berr = error.encode(conn.client_encoding, "replace") if error else None - res = await conn.wait(copy_end(conn.pgconn, berr)) - nrows = res.command_tuples - self.cursor._rowcount = nrows if nrows is not None else -1 - self._finished = True + await self.connection.wait(self._finish_gen(error)) async def __aenter__(self) -> "AsyncCopy": - if self._finished: - raise TypeError("copy blocks can be used only once") + self._check_reuse() return self async def __aexit__( @@ -257,19 +261,7 @@ class AsyncCopy(BaseCopy["AsyncConnection"]): exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: - # no-op in COPY TO - if self._pgresult.status == ExecStatus.COPY_OUT: - return - - if not exc_type: - if self.format == Format.BINARY and not self._first_row: - # send EOF only if we copied binary rows (_first_row is False) - await self.write(b"\xff\xff") - await self._finish() - else: - await self._finish( - f"error from Python: {exc_type.__qualname__} - {exc_val}" - ) + await self.connection.wait(self._exit_gen(exc_type, exc_val)) async def __aiter__(self) -> AsyncIterator[bytes]: while True: