]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Cleanup of Copy attributes and parameters
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 12 Nov 2020 19:37:31 +0000 (19:37 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 12 Nov 2020 19:37:31 +0000 (19:37 +0000)
psycopg3/psycopg3/copy.py
psycopg3/psycopg3/cursor.py

index c6201287091809953784c4b77e76245c662207c3..e3831a6fc9e905a00c17095db1c11771f9466b8e 100644 (file)
@@ -20,16 +20,11 @@ if TYPE_CHECKING:
 
 
 class BaseCopy(Generic[ConnectionType]):
-    def __init__(
-        self,
-        connection: ConnectionType,
-        transformer: Transformer,
-        result: "PGresult",
-    ):
+    def __init__(self, connection: ConnectionType, transformer: Transformer):
         self.connection = connection
-        self._transformer = transformer
-        self.pgresult = result
-        self.format = result.binary_tuples
+        self.transformer = transformer
+
+        self.format = self.pgresult.binary_tuples
         self._first_row = True
         self._finished = False
         self._encoding: str = ""
@@ -44,13 +39,10 @@ class BaseCopy(Generic[ConnectionType]):
         return self._finished
 
     @property
-    def pgresult(self) -> Optional["PGresult"]:
-        return self._pgresult
-
-    @pgresult.setter
-    def pgresult(self, result: Optional["PGresult"]) -> None:
-        self._pgresult = result
-        self._transformer.pgresult = result
+    def pgresult(self) -> "PGresult":
+        pgresult = self.transformer.pgresult
+        assert pgresult, "The Transformer doesn't have a PGresult set"
+        return pgresult
 
     def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
         if isinstance(data, bytes):
@@ -77,7 +69,7 @@ class BaseCopy(Generic[ConnectionType]):
         out: List[Optional[bytes]] = []
         for item in row:
             if item is not None:
-                dumper = self._transformer.get_dumper(item, self.format)
+                dumper = self.transformer.get_dumper(item, self.format)
                 out.append(dumper.dump(item))
             else:
                 out.append(None)
index 66a490f786e91393bb437cd552d18b937fe582de..e21d51479fba871d387b22fd74eabef692ba3c7e 100644 (file)
@@ -565,13 +565,10 @@ class Cursor(BaseCursor["Connection"]):
             self._execute_send(statement, vars, no_pqexec=True)
             gen = execute(self.connection.pgconn)
             results = self.connection.wait(gen)
+            self._check_copy_results(results)
+            self.pgresult = results[0]  # will set it on the transformer too
 
-        self._check_copy_results(results)
-        return Copy(
-            connection=self.connection,
-            transformer=self._transformer,
-            result=results[0],
-        )
+        return Copy(connection=self.connection, transformer=self._transformer)
 
 
 class AsyncCursor(BaseCursor["AsyncConnection"]):
@@ -696,12 +693,12 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
             self._execute_send(statement, vars, no_pqexec=True)
             gen = execute(self.connection.pgconn)
             results = await self.connection.wait(gen)
+            self._check_copy_results(results)
+            self.pgresult = results[0]  # will set it on the transformer too
 
-        self._check_copy_results(results)
         return AsyncCopy(
             connection=self.connection,
             transformer=self._transformer,
-            result=results[0],
         )