From: Daniele Varrazzo Date: Sun, 15 Nov 2020 01:57:14 +0000 (+0000) Subject: Copy object interface simplified X-Git-Tag: 3.0.dev0~351^2~14 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7dfa8fe1afa9b7643820ce420fe3b0791af306b3;p=thirdparty%2Fpsycopg.git Copy object interface simplified Dropped methods and code paths no more useful now that the user can only interact with Copy as a context. --- diff --git a/psycopg3/psycopg3/copy.py b/psycopg3/psycopg3/copy.py index 4dfc95526..494d77196 100644 --- a/psycopg3/psycopg3/copy.py +++ b/psycopg3/psycopg3/copy.py @@ -24,48 +24,23 @@ class BaseCopy(Generic[ConnectionType]): self.connection = connection self.transformer = transformer - self.format = self.pgresult.binary_tuples + assert ( + self.transformer.pgresult + ), "The Transformer doesn't have a PGresult set" + self._pgresult: "PGresult" = self.transformer.pgresult + + self.format = self._pgresult.binary_tuples + self._encoding = self.connection.client_encoding self._first_row = True self._finished = False - self._encoding: str = "" if self.format == Format.TEXT: - self._format_row = self._format_row_text - else: - self._format_row = self._format_row_binary - - @property - def finished(self) -> bool: - return self._finished - - @property - def pgresult(self) -> "PGresult": - pgresult = self.transformer.pgresult - assert pgresult, "The Transformer doesn't have a PGresult set" - return pgresult - - def _ensure_bytes(self, data: Union[bytes, str]) -> bytes: - if isinstance(data, bytes): - return data - - elif isinstance(data, str): - if self._encoding: - return data.encode(self._encoding) - - if ( - self.pgresult is None - or self.pgresult.binary_tuples == Format.BINARY - ): - raise TypeError( - "cannot copy str data in binary mode: use bytes instead" - ) - self._encoding = self.connection.client_encoding - return data.encode(self._encoding) - + self._format_copy_row = self._format_row_text else: - raise TypeError(f"can't write {type(data).__name__}") + self._format_copy_row = self._format_row_binary - def format_row(self, row: Sequence[Any]) -> bytes: + def _format_row(self, row: Sequence[Any]) -> bytes: + """Convert a Python sequence to the data to send for copy""" out: List[Optional[bytes]] = [] for item in row: if item is not None: @@ -73,9 +48,10 @@ class BaseCopy(Generic[ConnectionType]): out.append(dumper.dump(item)) else: out.append(None) - return self._format_row(out) + return self._format_copy_row(out) def _format_row_text(self, row: Sequence[Optional[bytes]]) -> bytes: + """Convert a row of adapted data to the data to send for text copy""" return ( b"\t".join( _bsrepl_re.sub(_bsrepl_sub, item) @@ -92,6 +68,7 @@ class BaseCopy(Generic[ConnectionType]): __int2_struct: struct.Struct = struct.Struct("!h"), __int4_struct: struct.Struct = struct.Struct("!i"), ) -> bytes: + """Convert a row of adapted data to the data to send for binary copy""" out = [] if self._first_row: out.append( @@ -112,6 +89,20 @@ class BaseCopy(Generic[ConnectionType]): return b"".join(out) + def _ensure_bytes(self, data: Union[bytes, str]) -> bytes: + if isinstance(data, bytes): + return data + + elif isinstance(data, str): + if self._pgresult.binary_tuples == Format.BINARY: + raise TypeError( + "cannot copy str data in binary mode: use bytes instead" + ) + return data.encode(self._encoding) + + else: + raise TypeError(f"can't write {type(data).__name__}") + def _bsrepl_sub( m: Match[bytes], @@ -156,10 +147,10 @@ class Copy(BaseCopy["Connection"]): def write_row(self, row: Sequence[Any]) -> None: """Write a record after a :sql:`COPY FROM` operation.""" - data = self.format_row(row) + data = self._format_row(row) self.write(data) - def finish(self, error: str = "") -> None: + def _finish(self, error: str = "") -> None: """Terminate a :sql:`COPY FROM` operation.""" conn = self.connection berr = error.encode(conn.client_encoding, "replace") if error else None @@ -176,21 +167,21 @@ class Copy(BaseCopy["Connection"]): exc_tb: Optional[TracebackType], ) -> None: # no-op in COPY TO - if self.pgresult.status == ExecStatus.COPY_OUT: + if self._pgresult.status == ExecStatus.COPY_OUT: return if not exc_type: if self.format == Format.BINARY and not self._first_row: # send EOF only if we copied binary rows (_first_row is False) self.write(b"\xff\xff") - self.finish() + self._finish() else: - self.finish( + self._finish( f"error from Python: {exc_type.__qualname__} - {exc_val}" ) def __iter__(self) -> Iterator[bytes]: - while 1: + while True: data = self.read() if data is None: break @@ -216,10 +207,10 @@ class AsyncCopy(BaseCopy["AsyncConnection"]): await conn.wait(copy_to(conn.pgconn, self._ensure_bytes(buffer))) async def write_row(self, row: Sequence[Any]) -> None: - data = self.format_row(row) + data = self._format_row(row) await self.write(data) - async def finish(self, error: str = "") -> None: + async def _finish(self, error: str = "") -> None: conn = self.connection berr = error.encode(conn.client_encoding, "replace") if error else None await conn.wait(copy_end(conn.pgconn, berr)) @@ -235,21 +226,21 @@ class AsyncCopy(BaseCopy["AsyncConnection"]): exc_tb: Optional[TracebackType], ) -> None: # no-op in COPY TO - if self.pgresult.status == ExecStatus.COPY_OUT: + if self._pgresult.status == ExecStatus.COPY_OUT: return if not exc_type: if self.format == Format.BINARY and not self._first_row: # send EOF only if we copied binary rows (_first_row is False) await self.write(b"\xff\xff") - await self.finish() + await self._finish() else: - await self.finish( + await self._finish( f"error from Python: {exc_type.__qualname__} - {exc_val}" ) async def __aiter__(self) -> AsyncIterator[bytes]: - while 1: + while True: data = await self.read() if data is None: break