]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: implement ServerCursor with a mixin rather than an helper
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 May 2022 12:38:37 +0000 (14:38 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 12 May 2022 13:36:11 +0000 (15:36 +0200)
This was the original intention, but some older Mypy limitation was
stopping us from doing so. Or I wasn't good enough at it.

psycopg/psycopg/server_cursor.py

index d7913090a34dd0aff52f9ae00f24f7478b32e05b..f8ebf60ad4f4036f6c4f2593e363c67da59470a2 100644 (file)
@@ -4,7 +4,7 @@ psycopg server-side cursor objects.
 
 # Copyright (C) 2020 The Psycopg Team
 
-from typing import Any, AsyncIterator, Generic, List, Iterable, Iterator
+from typing import Any, AsyncIterator, List, Iterable, Iterator
 from typing import Optional, TypeVar, TYPE_CHECKING, overload
 from warnings import warn
 
@@ -23,13 +23,10 @@ if TYPE_CHECKING:
 DEFAULT_ITERSIZE = 100
 
 
-class ServerCursorHelper(Generic[ConnectionType, Row]):
-    __slots__ = ("name", "scrollable", "withhold", "described", "_format")
-    """Helper object for common ServerCursor code.
+class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
+    """Mixin to add ServerCursor behaviour and implementation a BaseCursor."""
 
-    TODO: this should be a mixin, but couldn't find a way to work it
-    correctly with the generic.
-    """
+    __slots__ = "_name _scrollable _withhold _described itersize _format".split()
 
     def __init__(
         self,
@@ -37,136 +34,151 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
         scrollable: Optional[bool],
         withhold: bool,
     ):
-        self.name = name
-        self.scrollable = scrollable
-        self.withhold = withhold
-        self.described = False
+        self._name = name
+        self._scrollable = scrollable
+        self._withhold = withhold
+        self._described = False
+        self.itersize: int = DEFAULT_ITERSIZE
         self._format = pq.Format.TEXT
 
-    def _repr(self, cur: BaseCursor[ConnectionType, Row]) -> str:
+    def __repr__(self) -> str:
         # Insert the name as the second word
-        parts = parts = BaseCursor.__repr__(cur).split(None, 1)
-        parts.insert(1, f"{self.name!r}")
+        parts = super().__repr__().split(None, 1)
+        parts.insert(1, f"{self._name!r}")
         return " ".join(parts)
 
+    @property
+    def name(self) -> str:
+        """The name of the cursor."""
+        return self._name
+
+    @property
+    def scrollable(self) -> Optional[bool]:
+        """
+        Whether the cursor is scrollable or not.
+
+        If `!None` leave the choice to the server. Use `!True` if you want to
+        use `scroll()` on the cursor.
+        """
+        return self._scrollable
+
+    @property
+    def withhold(self) -> bool:
+        """
+        If the cursor can be used after the creating transaction has committed.
+        """
+        return self._withhold
+
     def _declare_gen(
         self,
-        cur: BaseCursor[ConnectionType, Row],
         query: Query,
         params: Optional[Params] = None,
         binary: Optional[bool] = None,
     ) -> PQGen[None]:
         """Generator implementing `ServerCursor.execute()`."""
 
-        query = self._make_declare_statement(cur, query)
+        query = self._make_declare_statement(query)
 
         # If the cursor is being reused, the previous one must be closed.
-        if self.described:
-            yield from self._close_gen(cur)
-            self.described = False
-
-        yield from cur._start_query(query)
-        pgq = cur._convert_query(query, params)
-        cur._execute_send(pgq, no_pqexec=True)
-        results = yield from execute(cur._conn.pgconn)
+        if self._described:
+            yield from self._close_gen()
+            self._described = False
+
+        yield from self._start_query(query)
+        pgq = self._convert_query(query, params)
+        self._execute_send(pgq, no_pqexec=True)
+        results = yield from execute(self._conn.pgconn)
         if results[-1].status != pq.ExecStatus.COMMAND_OK:
-            cur._raise_for_result(results[-1])
+            self._raise_for_result(results[-1])
 
         # Set the format, which will be used by describe and fetch operations
         if binary is None:
-            self._format = cur.format
+            self._format = self.format
         else:
             self._format = pq.Format.BINARY if binary else pq.Format.TEXT
 
         # The above result only returned COMMAND_OK. Get the cursor shape
-        yield from self._describe_gen(cur)
+        yield from self._describe_gen()
 
-    def _describe_gen(self, cur: BaseCursor[ConnectionType, Row]) -> PQGen[None]:
-        conn = cur._conn
-        conn.pgconn.send_describe_portal(self.name.encode(cur._encoding))
-        results = yield from execute(conn.pgconn)
-        cur._check_results(results)
-        cur._results = results
-        cur._set_current_result(0, format=self._format)
-        self.described = True
+    def _describe_gen(self) -> PQGen[None]:
+        self._pgconn.send_describe_portal(self._name.encode(self._encoding))
+        results = yield from execute(self._pgconn)
+        self._check_results(results)
+        self._results = results
+        self._set_current_result(0, format=self._format)
+        self._described = True
 
-    def _close_gen(self, cur: BaseCursor[ConnectionType, Row]) -> PQGen[None]:
-        ts = cur._conn.pgconn.transaction_status
+    def _close_gen(self) -> PQGen[None]:
+        ts = self._conn.pgconn.transaction_status
 
         # if the connection is not in a sane state, don't even try
         if ts not in (pq.TransactionStatus.IDLE, pq.TransactionStatus.INTRANS):
             return
 
         # If we are IDLE, a WITHOUT HOLD cursor will surely have gone already.
-        if not self.withhold and ts == pq.TransactionStatus.IDLE:
+        if not self._withhold and ts == pq.TransactionStatus.IDLE:
             return
 
         # if we didn't declare the cursor ourselves we still have to close it
         # but we must make sure it exists.
-        if not self.described:
+        if not self._described:
             query = sql.SQL(
                 "SELECT 1 FROM pg_catalog.pg_cursors WHERE name = {}"
-            ).format(sql.Literal(self.name))
-            res = yield from cur._conn._exec_command(query)
+            ).format(sql.Literal(self._name))
+            res = yield from self._conn._exec_command(query)
             # pipeline mode otherwise, unsupported here.
             assert res is not None
             if res.ntuples == 0:
                 return
 
-        query = sql.SQL("CLOSE {}").format(sql.Identifier(self.name))
-        yield from cur._conn._exec_command(query)
+        query = sql.SQL("CLOSE {}").format(sql.Identifier(self._name))
+        yield from self._conn._exec_command(query)
 
-    def _fetch_gen(
-        self, cur: BaseCursor[ConnectionType, Row], num: Optional[int]
-    ) -> PQGen[List[Row]]:
-        if cur.closed:
+    def _fetch_gen(self, num: Optional[int]) -> PQGen[List[Row]]:
+        if self.closed:
             raise e.InterfaceError("the cursor is closed")
         # If we are stealing the cursor, make sure we know its shape
-        if not self.described:
-            yield from cur._start_query()
-            yield from self._describe_gen(cur)
+        if not self._described:
+            yield from self._start_query()
+            yield from self._describe_gen()
 
         query = sql.SQL("FETCH FORWARD {} FROM {}").format(
             sql.SQL("ALL") if num is None else sql.Literal(num),
-            sql.Identifier(self.name),
+            sql.Identifier(self._name),
         )
-        res = yield from cur._conn._exec_command(query, result_format=self._format)
+        res = yield from self._conn._exec_command(query, result_format=self._format)
         # pipeline mode otherwise, unsupported here.
         assert res is not None
 
-        cur.pgresult = res
-        cur._tx.set_pgresult(res, set_loaders=False)
-        return cur._tx.load_rows(0, res.ntuples, cur._make_row)
+        self.pgresult = res
+        self._tx.set_pgresult(res, set_loaders=False)
+        return self._tx.load_rows(0, res.ntuples, self._make_row)
 
-    def _scroll_gen(
-        self, cur: BaseCursor[ConnectionType, Row], value: int, mode: str
-    ) -> PQGen[None]:
+    def _scroll_gen(self, value: int, mode: str) -> PQGen[None]:
         if mode not in ("relative", "absolute"):
             raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
         query = sql.SQL("MOVE{} {} FROM {}").format(
             sql.SQL(" ABSOLUTE" if mode == "absolute" else ""),
             sql.Literal(value),
-            sql.Identifier(self.name),
+            sql.Identifier(self._name),
         )
-        yield from cur._conn._exec_command(query)
+        yield from self._conn._exec_command(query)
 
-    def _make_declare_statement(
-        self, cur: BaseCursor[ConnectionType, Row], query: Query
-    ) -> sql.Composable:
+    def _make_declare_statement(self, query: Query) -> sql.Composable:
 
         if isinstance(query, bytes):
-            query = query.decode(cur._encoding)
+            query = query.decode(self._encoding)
         if not isinstance(query, sql.Composable):
             query = sql.SQL(query)
 
         parts = [
             sql.SQL("DECLARE"),
-            sql.Identifier(self.name),
+            sql.Identifier(self._name),
         ]
-        if self.scrollable is not None:
-            parts.append(sql.SQL("SCROLL" if self.scrollable else "NO SCROLL"))
+        if self._scrollable is not None:
+            parts.append(sql.SQL("SCROLL" if self._scrollable else "NO SCROLL"))
         parts.append(sql.SQL("CURSOR"))
-        if self.withhold:
+        if self._withhold:
             parts.append(sql.SQL("WITH HOLD"))
         parts.append(sql.SQL("FOR"))
         parts.append(query)
@@ -178,9 +190,9 @@ _C = TypeVar("_C", bound="ServerCursor[Any]")
 _AC = TypeVar("_AC", bound="AsyncServerCursor[Any]")
 
 
-class ServerCursor(Cursor[Row]):
+class ServerCursor(ServerCursorMixin["Connection[Row]", Row], Cursor[Row]):
     __module__ = "psycopg"
-    __slots__ = ("_helper", "itersize")
+    __slots__ = ()
 
     @overload
     def __init__(
@@ -214,10 +226,10 @@ class ServerCursor(Cursor[Row]):
         scrollable: Optional[bool] = None,
         withhold: bool = False,
     ):
-        super().__init__(connection, row_factory=row_factory or connection.row_factory)
-        self._helper: ServerCursorHelper["Connection[Any]", Row]
-        self._helper = ServerCursorHelper(name, scrollable, withhold)
-        self.itersize: int = DEFAULT_ITERSIZE
+        Cursor.__init__(
+            self, connection, row_factory=row_factory or connection.row_factory
+        )
+        ServerCursorMixin.__init__(self, name, scrollable, withhold)
 
     def __del__(self) -> None:
         if not self.closed:
@@ -227,31 +239,6 @@ class ServerCursor(Cursor[Row]):
                 ResourceWarning,
             )
 
-    def __repr__(self) -> str:
-        return self._helper._repr(self)
-
-    @property
-    def name(self) -> str:
-        """The name of the cursor."""
-        return self._helper.name
-
-    @property
-    def scrollable(self) -> Optional[bool]:
-        """
-        Whether the cursor is scrollable or not.
-
-        If `!None` leave the choice to the server. Use `!True` if you want to
-        use `scroll()` on the cursor.
-        """
-        return self._helper.scrollable
-
-    @property
-    def withhold(self) -> bool:
-        """
-        If the cursor can be used after the creating transaction has committed.
-        """
-        return self._helper.withhold
-
     def close(self) -> None:
         """
         Close the current cursor and free associated resources.
@@ -260,7 +247,7 @@ class ServerCursor(Cursor[Row]):
             if self.closed:
                 return
             if not self._conn.closed:
-                self._conn.wait(self._helper._close_gen(self))
+                self._conn.wait(self._close_gen())
             super().close()
 
     def execute(
@@ -283,7 +270,7 @@ class ServerCursor(Cursor[Row]):
 
         try:
             with self._conn.lock:
-                self._conn.wait(self._helper._declare_gen(self, query, params, binary))
+                self._conn.wait(self._declare_gen(query, params, binary))
         except e.Error as ex:
             raise ex.with_traceback(None)
 
@@ -301,7 +288,7 @@ class ServerCursor(Cursor[Row]):
 
     def fetchone(self) -> Optional[Row]:
         with self._conn.lock:
-            recs = self._conn.wait(self._helper._fetch_gen(self, 1))
+            recs = self._conn.wait(self._fetch_gen(1))
         if recs:
             self._pos += 1
             return recs[0]
@@ -312,20 +299,20 @@ class ServerCursor(Cursor[Row]):
         if not size:
             size = self.arraysize
         with self._conn.lock:
-            recs = self._conn.wait(self._helper._fetch_gen(self, size))
+            recs = self._conn.wait(self._fetch_gen(size))
         self._pos += len(recs)
         return recs
 
     def fetchall(self) -> List[Row]:
         with self._conn.lock:
-            recs = self._conn.wait(self._helper._fetch_gen(self, None))
+            recs = self._conn.wait(self._fetch_gen(None))
         self._pos += len(recs)
         return recs
 
     def __iter__(self) -> Iterator[Row]:
         while True:
             with self._conn.lock:
-                recs = self._conn.wait(self._helper._fetch_gen(self, self.itersize))
+                recs = self._conn.wait(self._fetch_gen(self.itersize))
             for rec in recs:
                 self._pos += 1
                 yield rec
@@ -334,7 +321,7 @@ class ServerCursor(Cursor[Row]):
 
     def scroll(self, value: int, mode: str = "relative") -> None:
         with self._conn.lock:
-            self._conn.wait(self._helper._scroll_gen(self, value, mode))
+            self._conn.wait(self._scroll_gen(value, mode))
         # Postgres doesn't have a reliable way to report a cursor out of bound
         if mode == "relative":
             self._pos += value
@@ -342,9 +329,11 @@ class ServerCursor(Cursor[Row]):
             self._pos = value
 
 
-class AsyncServerCursor(AsyncCursor[Row]):
+class AsyncServerCursor(
+    ServerCursorMixin["AsyncConnection[Row]", Row], AsyncCursor[Row]
+):
     __module__ = "psycopg"
-    __slots__ = ("_helper", "itersize")
+    __slots__ = ()
 
     @overload
     def __init__(
@@ -378,10 +367,10 @@ class AsyncServerCursor(AsyncCursor[Row]):
         scrollable: Optional[bool] = None,
         withhold: bool = False,
     ):
-        super().__init__(connection, row_factory=row_factory or connection.row_factory)
-        self._helper: ServerCursorHelper["AsyncConnection[Any]", Row]
-        self._helper = ServerCursorHelper(name, scrollable, withhold)
-        self.itersize: int = DEFAULT_ITERSIZE
+        AsyncCursor.__init__(
+            self, connection, row_factory=row_factory or connection.row_factory
+        )
+        ServerCursorMixin.__init__(self, name, scrollable, withhold)
 
     def __del__(self) -> None:
         if not self.closed:
@@ -391,27 +380,12 @@ class AsyncServerCursor(AsyncCursor[Row]):
                 ResourceWarning,
             )
 
-    def __repr__(self) -> str:
-        return self._helper._repr(self)
-
-    @property
-    def name(self) -> str:
-        return self._helper.name
-
-    @property
-    def scrollable(self) -> Optional[bool]:
-        return self._helper.scrollable
-
-    @property
-    def withhold(self) -> bool:
-        return self._helper.withhold
-
     async def close(self) -> None:
         async with self._conn.lock:
             if self.closed:
                 return
             if not self._conn.closed:
-                await self._conn.wait(self._helper._close_gen(self))
+                await self._conn.wait(self._close_gen())
             await super().close()
 
     async def execute(
@@ -431,9 +405,7 @@ class AsyncServerCursor(AsyncCursor[Row]):
 
         try:
             async with self._conn.lock:
-                await self._conn.wait(
-                    self._helper._declare_gen(self, query, params, binary)
-                )
+                await self._conn.wait(self._declare_gen(query, params, binary))
         except e.Error as ex:
             raise ex.with_traceback(None)
 
@@ -450,7 +422,7 @@ class AsyncServerCursor(AsyncCursor[Row]):
 
     async def fetchone(self) -> Optional[Row]:
         async with self._conn.lock:
-            recs = await self._conn.wait(self._helper._fetch_gen(self, 1))
+            recs = await self._conn.wait(self._fetch_gen(1))
         if recs:
             self._pos += 1
             return recs[0]
@@ -461,22 +433,20 @@ class AsyncServerCursor(AsyncCursor[Row]):
         if not size:
             size = self.arraysize
         async with self._conn.lock:
-            recs = await self._conn.wait(self._helper._fetch_gen(self, size))
+            recs = await self._conn.wait(self._fetch_gen(size))
         self._pos += len(recs)
         return recs
 
     async def fetchall(self) -> List[Row]:
         async with self._conn.lock:
-            recs = await self._conn.wait(self._helper._fetch_gen(self, None))
+            recs = await self._conn.wait(self._fetch_gen(None))
         self._pos += len(recs)
         return recs
 
     async def __aiter__(self) -> AsyncIterator[Row]:
         while True:
             async with self._conn.lock:
-                recs = await self._conn.wait(
-                    self._helper._fetch_gen(self, self.itersize)
-                )
+                recs = await self._conn.wait(self._fetch_gen(self.itersize))
             for rec in recs:
                 self._pos += 1
                 yield rec
@@ -485,4 +455,4 @@ class AsyncServerCursor(AsyncCursor[Row]):
 
     async def scroll(self, value: int, mode: str = "relative") -> None:
         async with self._conn.lock:
-            await self._conn.wait(self._helper._scroll_gen(self, value, mode))
+            await self._conn.wait(self._scroll_gen(value, mode))