]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Make ServerCursorHelper.format private
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 28 Oct 2021 16:11:16 +0000 (18:11 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 28 Oct 2021 16:11:16 +0000 (18:11 +0200)
Reduce duplicate code, remove a race condition (the format set outside
the lock section), make _declare_gen() interface more similar to
_execute_gen().

psycopg/psycopg/server_cursor.py

index 7c2e2e7e9b8013a8db8560a3a2ead4427bce3a58..d050a7186f5449f4278d1516ac4df77d732adc38 100644 (file)
@@ -25,7 +25,7 @@ DEFAULT_ITERSIZE = 100
 
 
 class ServerCursorHelper(Generic[ConnectionType, Row]):
-    __slots__ = ("name", "scrollable", "withhold", "described", "format")
+    __slots__ = ("name", "scrollable", "withhold", "described", "_format")
     """Helper object for common ServerCursor code.
 
     TODO: this should be a mixin, but couldn't find a way to work it
@@ -42,7 +42,7 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
         self.scrollable = scrollable
         self.withhold = withhold
         self.described = False
-        self.format = pq.Format.TEXT
+        self._format = pq.Format.TEXT
 
     def _repr(self, cur: BaseCursor[ConnectionType, Row]) -> str:
         cls = f"{cur.__class__.__module__}.{cur.__class__.__qualname__}"
@@ -60,6 +60,7 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
         cur: BaseCursor[ConnectionType, Row],
         query: Query,
         params: Optional[Params] = None,
+        binary: Optional[bool] = None,
     ) -> PQGen[None]:
         """Generator implementing `ServerCursor.execute()`."""
 
@@ -78,6 +79,12 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
         if results[-1].status != pq.ExecStatus.COMMAND_OK:
             cur._raise_from_results(results)
 
+        # Set the format, which will be used by describe and fetch operations
+        if binary is None:
+            self._format = cur.format
+        else:
+            self._format = pq.Format.BINARY if binary else pq.Format.TEXT
+
         # The above result only returned COMMAND_OK. Get the cursor shape
         yield from self._describe_gen(cur)
 
@@ -89,7 +96,7 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
             self.name.encode(pgconn_encoding(conn.pgconn))
         )
         results = yield from execute(conn.pgconn)
-        cur._execute_results(results, format=self.format)
+        cur._execute_results(results, format=self._format)
         self.described = True
 
     def _close_gen(self, cur: BaseCursor[ConnectionType, Row]) -> PQGen[None]:
@@ -133,7 +140,7 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
             howmuch, sql.Identifier(self.name)
         )
         res = yield from cur._conn._exec_command(
-            query, result_format=self.format
+            query, result_format=self._format
         )
 
         cur.pgresult = res
@@ -256,15 +263,11 @@ class ServerCursor(Cursor[Row]):
         """
         if kwargs:
             raise TypeError(f"keyword not supported: {list(kwargs)[0]}")
-        helper = self._helper
-
-        if binary is None:
-            helper.format = self.format
-        else:
-            helper.format = pq.Format.BINARY if binary else pq.Format.TEXT
-
         with self._conn.lock:
-            self._conn.wait(helper._declare_gen(self, query, params))
+            self._conn.wait(
+                self._helper._declare_gen(self, query, params, binary)
+            )
+
         return self
 
     def executemany(self, query: Query, params_seq: Sequence[Params]) -> None:
@@ -376,15 +379,11 @@ class AsyncServerCursor(AsyncCursor[Row]):
     ) -> _AC:
         if kwargs:
             raise TypeError(f"keyword not supported: {list(kwargs)[0]}")
-        helper = self._helper
-
-        if binary is None:
-            helper.format = self.format
-        else:
-            helper.format = pq.Format.BINARY if binary else pq.Format.TEXT
-
         async with self._conn.lock:
-            await self._conn.wait(helper._declare_gen(self, query, params))
+            await self._conn.wait(
+                self._helper._declare_gen(self, query, params, binary)
+            )
+
         return self
 
     async def executemany(