From: Daniele Varrazzo Date: Wed, 11 May 2022 12:38:37 +0000 (+0200) Subject: refactor: implement ServerCursor with a mixin rather than an helper X-Git-Tag: 3.1~107 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3f297f26bb07696b860483e5d533838265f1f72f;p=thirdparty%2Fpsycopg.git refactor: implement ServerCursor with a mixin rather than an helper This was the original intention, but some older Mypy limitation was stopping us from doing so. Or I wasn't good enough at it. --- diff --git a/psycopg/psycopg/server_cursor.py b/psycopg/psycopg/server_cursor.py index d7913090a..f8ebf60ad 100644 --- a/psycopg/psycopg/server_cursor.py +++ b/psycopg/psycopg/server_cursor.py @@ -4,7 +4,7 @@ psycopg server-side cursor objects. # Copyright (C) 2020 The Psycopg Team -from typing import Any, AsyncIterator, Generic, List, Iterable, Iterator +from typing import Any, AsyncIterator, List, Iterable, Iterator from typing import Optional, TypeVar, TYPE_CHECKING, overload from warnings import warn @@ -23,13 +23,10 @@ if TYPE_CHECKING: DEFAULT_ITERSIZE = 100 -class ServerCursorHelper(Generic[ConnectionType, Row]): - __slots__ = ("name", "scrollable", "withhold", "described", "_format") - """Helper object for common ServerCursor code. +class ServerCursorMixin(BaseCursor[ConnectionType, Row]): + """Mixin to add ServerCursor behaviour and implementation a BaseCursor.""" - TODO: this should be a mixin, but couldn't find a way to work it - correctly with the generic. - """ + __slots__ = "_name _scrollable _withhold _described itersize _format".split() def __init__( self, @@ -37,136 +34,151 @@ class ServerCursorHelper(Generic[ConnectionType, Row]): scrollable: Optional[bool], withhold: bool, ): - self.name = name - self.scrollable = scrollable - self.withhold = withhold - self.described = False + self._name = name + self._scrollable = scrollable + self._withhold = withhold + self._described = False + self.itersize: int = DEFAULT_ITERSIZE self._format = pq.Format.TEXT - def _repr(self, cur: BaseCursor[ConnectionType, Row]) -> str: + def __repr__(self) -> str: # Insert the name as the second word - parts = parts = BaseCursor.__repr__(cur).split(None, 1) - parts.insert(1, f"{self.name!r}") + parts = super().__repr__().split(None, 1) + parts.insert(1, f"{self._name!r}") return " ".join(parts) + @property + def name(self) -> str: + """The name of the cursor.""" + return self._name + + @property + def scrollable(self) -> Optional[bool]: + """ + Whether the cursor is scrollable or not. + + If `!None` leave the choice to the server. Use `!True` if you want to + use `scroll()` on the cursor. + """ + return self._scrollable + + @property + def withhold(self) -> bool: + """ + If the cursor can be used after the creating transaction has committed. + """ + return self._withhold + def _declare_gen( self, - cur: BaseCursor[ConnectionType, Row], query: Query, params: Optional[Params] = None, binary: Optional[bool] = None, ) -> PQGen[None]: """Generator implementing `ServerCursor.execute()`.""" - query = self._make_declare_statement(cur, query) + query = self._make_declare_statement(query) # If the cursor is being reused, the previous one must be closed. - if self.described: - yield from self._close_gen(cur) - self.described = False - - yield from cur._start_query(query) - pgq = cur._convert_query(query, params) - cur._execute_send(pgq, no_pqexec=True) - results = yield from execute(cur._conn.pgconn) + if self._described: + yield from self._close_gen() + self._described = False + + yield from self._start_query(query) + pgq = self._convert_query(query, params) + self._execute_send(pgq, no_pqexec=True) + results = yield from execute(self._conn.pgconn) if results[-1].status != pq.ExecStatus.COMMAND_OK: - cur._raise_for_result(results[-1]) + self._raise_for_result(results[-1]) # Set the format, which will be used by describe and fetch operations if binary is None: - self._format = cur.format + self._format = self.format else: self._format = pq.Format.BINARY if binary else pq.Format.TEXT # The above result only returned COMMAND_OK. Get the cursor shape - yield from self._describe_gen(cur) + yield from self._describe_gen() - def _describe_gen(self, cur: BaseCursor[ConnectionType, Row]) -> PQGen[None]: - conn = cur._conn - conn.pgconn.send_describe_portal(self.name.encode(cur._encoding)) - results = yield from execute(conn.pgconn) - cur._check_results(results) - cur._results = results - cur._set_current_result(0, format=self._format) - self.described = True + def _describe_gen(self) -> PQGen[None]: + self._pgconn.send_describe_portal(self._name.encode(self._encoding)) + results = yield from execute(self._pgconn) + self._check_results(results) + self._results = results + self._set_current_result(0, format=self._format) + self._described = True - def _close_gen(self, cur: BaseCursor[ConnectionType, Row]) -> PQGen[None]: - ts = cur._conn.pgconn.transaction_status + def _close_gen(self) -> PQGen[None]: + ts = self._conn.pgconn.transaction_status # if the connection is not in a sane state, don't even try if ts not in (pq.TransactionStatus.IDLE, pq.TransactionStatus.INTRANS): return # If we are IDLE, a WITHOUT HOLD cursor will surely have gone already. - if not self.withhold and ts == pq.TransactionStatus.IDLE: + if not self._withhold and ts == pq.TransactionStatus.IDLE: return # if we didn't declare the cursor ourselves we still have to close it # but we must make sure it exists. - if not self.described: + if not self._described: query = sql.SQL( "SELECT 1 FROM pg_catalog.pg_cursors WHERE name = {}" - ).format(sql.Literal(self.name)) - res = yield from cur._conn._exec_command(query) + ).format(sql.Literal(self._name)) + res = yield from self._conn._exec_command(query) # pipeline mode otherwise, unsupported here. assert res is not None if res.ntuples == 0: return - query = sql.SQL("CLOSE {}").format(sql.Identifier(self.name)) - yield from cur._conn._exec_command(query) + query = sql.SQL("CLOSE {}").format(sql.Identifier(self._name)) + yield from self._conn._exec_command(query) - def _fetch_gen( - self, cur: BaseCursor[ConnectionType, Row], num: Optional[int] - ) -> PQGen[List[Row]]: - if cur.closed: + def _fetch_gen(self, num: Optional[int]) -> PQGen[List[Row]]: + if self.closed: raise e.InterfaceError("the cursor is closed") # If we are stealing the cursor, make sure we know its shape - if not self.described: - yield from cur._start_query() - yield from self._describe_gen(cur) + if not self._described: + yield from self._start_query() + yield from self._describe_gen() query = sql.SQL("FETCH FORWARD {} FROM {}").format( sql.SQL("ALL") if num is None else sql.Literal(num), - sql.Identifier(self.name), + sql.Identifier(self._name), ) - res = yield from cur._conn._exec_command(query, result_format=self._format) + res = yield from self._conn._exec_command(query, result_format=self._format) # pipeline mode otherwise, unsupported here. assert res is not None - cur.pgresult = res - cur._tx.set_pgresult(res, set_loaders=False) - return cur._tx.load_rows(0, res.ntuples, cur._make_row) + self.pgresult = res + self._tx.set_pgresult(res, set_loaders=False) + return self._tx.load_rows(0, res.ntuples, self._make_row) - def _scroll_gen( - self, cur: BaseCursor[ConnectionType, Row], value: int, mode: str - ) -> PQGen[None]: + def _scroll_gen(self, value: int, mode: str) -> PQGen[None]: if mode not in ("relative", "absolute"): raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'") query = sql.SQL("MOVE{} {} FROM {}").format( sql.SQL(" ABSOLUTE" if mode == "absolute" else ""), sql.Literal(value), - sql.Identifier(self.name), + sql.Identifier(self._name), ) - yield from cur._conn._exec_command(query) + yield from self._conn._exec_command(query) - def _make_declare_statement( - self, cur: BaseCursor[ConnectionType, Row], query: Query - ) -> sql.Composable: + def _make_declare_statement(self, query: Query) -> sql.Composable: if isinstance(query, bytes): - query = query.decode(cur._encoding) + query = query.decode(self._encoding) if not isinstance(query, sql.Composable): query = sql.SQL(query) parts = [ sql.SQL("DECLARE"), - sql.Identifier(self.name), + sql.Identifier(self._name), ] - if self.scrollable is not None: - parts.append(sql.SQL("SCROLL" if self.scrollable else "NO SCROLL")) + if self._scrollable is not None: + parts.append(sql.SQL("SCROLL" if self._scrollable else "NO SCROLL")) parts.append(sql.SQL("CURSOR")) - if self.withhold: + if self._withhold: parts.append(sql.SQL("WITH HOLD")) parts.append(sql.SQL("FOR")) parts.append(query) @@ -178,9 +190,9 @@ _C = TypeVar("_C", bound="ServerCursor[Any]") _AC = TypeVar("_AC", bound="AsyncServerCursor[Any]") -class ServerCursor(Cursor[Row]): +class ServerCursor(ServerCursorMixin["Connection[Row]", Row], Cursor[Row]): __module__ = "psycopg" - __slots__ = ("_helper", "itersize") + __slots__ = () @overload def __init__( @@ -214,10 +226,10 @@ class ServerCursor(Cursor[Row]): scrollable: Optional[bool] = None, withhold: bool = False, ): - super().__init__(connection, row_factory=row_factory or connection.row_factory) - self._helper: ServerCursorHelper["Connection[Any]", Row] - self._helper = ServerCursorHelper(name, scrollable, withhold) - self.itersize: int = DEFAULT_ITERSIZE + Cursor.__init__( + self, connection, row_factory=row_factory or connection.row_factory + ) + ServerCursorMixin.__init__(self, name, scrollable, withhold) def __del__(self) -> None: if not self.closed: @@ -227,31 +239,6 @@ class ServerCursor(Cursor[Row]): ResourceWarning, ) - def __repr__(self) -> str: - return self._helper._repr(self) - - @property - def name(self) -> str: - """The name of the cursor.""" - return self._helper.name - - @property - def scrollable(self) -> Optional[bool]: - """ - Whether the cursor is scrollable or not. - - If `!None` leave the choice to the server. Use `!True` if you want to - use `scroll()` on the cursor. - """ - return self._helper.scrollable - - @property - def withhold(self) -> bool: - """ - If the cursor can be used after the creating transaction has committed. - """ - return self._helper.withhold - def close(self) -> None: """ Close the current cursor and free associated resources. @@ -260,7 +247,7 @@ class ServerCursor(Cursor[Row]): if self.closed: return if not self._conn.closed: - self._conn.wait(self._helper._close_gen(self)) + self._conn.wait(self._close_gen()) super().close() def execute( @@ -283,7 +270,7 @@ class ServerCursor(Cursor[Row]): try: with self._conn.lock: - self._conn.wait(self._helper._declare_gen(self, query, params, binary)) + self._conn.wait(self._declare_gen(query, params, binary)) except e.Error as ex: raise ex.with_traceback(None) @@ -301,7 +288,7 @@ class ServerCursor(Cursor[Row]): def fetchone(self) -> Optional[Row]: with self._conn.lock: - recs = self._conn.wait(self._helper._fetch_gen(self, 1)) + recs = self._conn.wait(self._fetch_gen(1)) if recs: self._pos += 1 return recs[0] @@ -312,20 +299,20 @@ class ServerCursor(Cursor[Row]): if not size: size = self.arraysize with self._conn.lock: - recs = self._conn.wait(self._helper._fetch_gen(self, size)) + recs = self._conn.wait(self._fetch_gen(size)) self._pos += len(recs) return recs def fetchall(self) -> List[Row]: with self._conn.lock: - recs = self._conn.wait(self._helper._fetch_gen(self, None)) + recs = self._conn.wait(self._fetch_gen(None)) self._pos += len(recs) return recs def __iter__(self) -> Iterator[Row]: while True: with self._conn.lock: - recs = self._conn.wait(self._helper._fetch_gen(self, self.itersize)) + recs = self._conn.wait(self._fetch_gen(self.itersize)) for rec in recs: self._pos += 1 yield rec @@ -334,7 +321,7 @@ class ServerCursor(Cursor[Row]): def scroll(self, value: int, mode: str = "relative") -> None: with self._conn.lock: - self._conn.wait(self._helper._scroll_gen(self, value, mode)) + self._conn.wait(self._scroll_gen(value, mode)) # Postgres doesn't have a reliable way to report a cursor out of bound if mode == "relative": self._pos += value @@ -342,9 +329,11 @@ class ServerCursor(Cursor[Row]): self._pos = value -class AsyncServerCursor(AsyncCursor[Row]): +class AsyncServerCursor( + ServerCursorMixin["AsyncConnection[Row]", Row], AsyncCursor[Row] +): __module__ = "psycopg" - __slots__ = ("_helper", "itersize") + __slots__ = () @overload def __init__( @@ -378,10 +367,10 @@ class AsyncServerCursor(AsyncCursor[Row]): scrollable: Optional[bool] = None, withhold: bool = False, ): - super().__init__(connection, row_factory=row_factory or connection.row_factory) - self._helper: ServerCursorHelper["AsyncConnection[Any]", Row] - self._helper = ServerCursorHelper(name, scrollable, withhold) - self.itersize: int = DEFAULT_ITERSIZE + AsyncCursor.__init__( + self, connection, row_factory=row_factory or connection.row_factory + ) + ServerCursorMixin.__init__(self, name, scrollable, withhold) def __del__(self) -> None: if not self.closed: @@ -391,27 +380,12 @@ class AsyncServerCursor(AsyncCursor[Row]): ResourceWarning, ) - def __repr__(self) -> str: - return self._helper._repr(self) - - @property - def name(self) -> str: - return self._helper.name - - @property - def scrollable(self) -> Optional[bool]: - return self._helper.scrollable - - @property - def withhold(self) -> bool: - return self._helper.withhold - async def close(self) -> None: async with self._conn.lock: if self.closed: return if not self._conn.closed: - await self._conn.wait(self._helper._close_gen(self)) + await self._conn.wait(self._close_gen()) await super().close() async def execute( @@ -431,9 +405,7 @@ class AsyncServerCursor(AsyncCursor[Row]): try: async with self._conn.lock: - await self._conn.wait( - self._helper._declare_gen(self, query, params, binary) - ) + await self._conn.wait(self._declare_gen(query, params, binary)) except e.Error as ex: raise ex.with_traceback(None) @@ -450,7 +422,7 @@ class AsyncServerCursor(AsyncCursor[Row]): async def fetchone(self) -> Optional[Row]: async with self._conn.lock: - recs = await self._conn.wait(self._helper._fetch_gen(self, 1)) + recs = await self._conn.wait(self._fetch_gen(1)) if recs: self._pos += 1 return recs[0] @@ -461,22 +433,20 @@ class AsyncServerCursor(AsyncCursor[Row]): if not size: size = self.arraysize async with self._conn.lock: - recs = await self._conn.wait(self._helper._fetch_gen(self, size)) + recs = await self._conn.wait(self._fetch_gen(size)) self._pos += len(recs) return recs async def fetchall(self) -> List[Row]: async with self._conn.lock: - recs = await self._conn.wait(self._helper._fetch_gen(self, None)) + recs = await self._conn.wait(self._fetch_gen(None)) self._pos += len(recs) return recs async def __aiter__(self) -> AsyncIterator[Row]: while True: async with self._conn.lock: - recs = await self._conn.wait( - self._helper._fetch_gen(self, self.itersize) - ) + recs = await self._conn.wait(self._fetch_gen(self.itersize)) for rec in recs: self._pos += 1 yield rec @@ -485,4 +455,4 @@ class AsyncServerCursor(AsyncCursor[Row]): async def scroll(self, value: int, mode: str = "relative") -> None: async with self._conn.lock: - await self._conn.wait(self._helper._scroll_gen(self, value, mode)) + await self._conn.wait(self._scroll_gen(value, mode))