From: Daniele Varrazzo Date: Thu, 12 Nov 2020 19:37:31 +0000 (+0000) Subject: Cleanup of Copy attributes and parameters X-Git-Tag: 3.0.dev0~370 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6075fd2640290ea279e981757eb6aabb98b16c7d;p=thirdparty%2Fpsycopg.git Cleanup of Copy attributes and parameters --- diff --git a/psycopg3/psycopg3/copy.py b/psycopg3/psycopg3/copy.py index c62012870..e3831a6fc 100644 --- a/psycopg3/psycopg3/copy.py +++ b/psycopg3/psycopg3/copy.py @@ -20,16 +20,11 @@ if TYPE_CHECKING: class BaseCopy(Generic[ConnectionType]): - def __init__( - self, - connection: ConnectionType, - transformer: Transformer, - result: "PGresult", - ): + def __init__(self, connection: ConnectionType, transformer: Transformer): self.connection = connection - self._transformer = transformer - self.pgresult = result - self.format = result.binary_tuples + self.transformer = transformer + + self.format = self.pgresult.binary_tuples self._first_row = True self._finished = False self._encoding: str = "" @@ -44,13 +39,10 @@ class BaseCopy(Generic[ConnectionType]): return self._finished @property - def pgresult(self) -> Optional["PGresult"]: - return self._pgresult - - @pgresult.setter - def pgresult(self, result: Optional["PGresult"]) -> None: - self._pgresult = result - self._transformer.pgresult = result + def pgresult(self) -> "PGresult": + pgresult = self.transformer.pgresult + assert pgresult, "The Transformer doesn't have a PGresult set" + return pgresult def _ensure_bytes(self, data: Union[bytes, str]) -> bytes: if isinstance(data, bytes): @@ -77,7 +69,7 @@ class BaseCopy(Generic[ConnectionType]): out: List[Optional[bytes]] = [] for item in row: if item is not None: - dumper = self._transformer.get_dumper(item, self.format) + dumper = self.transformer.get_dumper(item, self.format) out.append(dumper.dump(item)) else: out.append(None) diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index 66a490f78..e21d51479 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -565,13 +565,10 @@ class Cursor(BaseCursor["Connection"]): self._execute_send(statement, vars, no_pqexec=True) gen = execute(self.connection.pgconn) results = self.connection.wait(gen) + self._check_copy_results(results) + self.pgresult = results[0] # will set it on the transformer too - self._check_copy_results(results) - return Copy( - connection=self.connection, - transformer=self._transformer, - result=results[0], - ) + return Copy(connection=self.connection, transformer=self._transformer) class AsyncCursor(BaseCursor["AsyncConnection"]): @@ -696,12 +693,12 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): self._execute_send(statement, vars, no_pqexec=True) gen = execute(self.connection.pgconn) results = await self.connection.wait(gen) + self._check_copy_results(results) + self.pgresult = results[0] # will set it on the transformer too - self._check_copy_results(results) return AsyncCopy( connection=self.connection, transformer=self._transformer, - result=results[0], )