From bb93a6d93bc7effcd71e447c1dd97070f5281b2d Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sun, 1 Aug 2021 14:55:48 +0200 Subject: [PATCH] Add AsyncRowFactory class The extra class allow clients to define their RowFactory just taking a Cursor or an AsyncCursor, which is easier if the program handles only one type of connection (sync or async). Using a Generic, the server-side cursors are now subclassed from the respective client-side cursor (sync and async), which allows to drop a bit of implementation duplication. --- docs/advanced/rows.rst | 14 ++++-- psycopg/psycopg/connection.py | 47 +++++++++++------- psycopg/psycopg/cursor.py | 85 ++++++++++++++++++++++---------- psycopg/psycopg/rows.py | 13 ++++- psycopg/psycopg/server_cursor.py | 68 ++++++++++--------------- tests/typing_example.py | 70 ++++++++++++++++++++++++-- 6 files changed, 202 insertions(+), 95 deletions(-) diff --git a/docs/advanced/rows.rst b/docs/advanced/rows.rst index 2d05e0257..3b4c50018 100644 --- a/docs/advanced/rows.rst +++ b/docs/advanced/rows.rst @@ -25,13 +25,21 @@ callable (formally the `~psycopg.rows.RowMaker` protocol) accepting a .. autoclass:: psycopg.rows.RowFactory() - .. method:: __call__(cursor: AnyCursor[Row]) -> RowMaker[Row] + .. method:: __call__(cursor: Cursor[Row]) -> RowMaker[Row] Inspect the result on a cursor and return a `RowMaker` to convert rows. - `!AnyCursor` may be either a `~psycopg.Cursor` or an - `~psycopg.AsyncCursor`. +.. autoclass:: psycopg.rows.AsyncRowFactory() + .. method:: __call__(cursor: AsyncCursor[Row]) -> RowMaker[Row] + + Inspect the result on a cursor and return a `RowMaker` to convert rows. + +Note that it's easy to implement an object implementing both `!RowFactory` and +`!AsyncRowFactory`: usually, everything you need to implement a row factory is +to access `~Cursor.description`, which is provided by both the cursor flavours. +The `psycopg` module also exposes a class `AnyCursor` which you may use if you +want to use the same row factory for both sync and async cursors. `~RowFactory` objects can be implemented as a class, for instance: diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index ee1d280d5..1a32c0561 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -9,7 +9,7 @@ import logging import warnings import threading from types import TracebackType -from typing import Any, AsyncIterator, Callable, Generic, Iterator, List +from typing import Any, AsyncIterator, Callable, cast, Generic, Iterator, List from typing import NamedTuple, Optional, Type, TypeVar, Union from typing import overload, TYPE_CHECKING from weakref import ref, ReferenceType @@ -25,7 +25,7 @@ from . import encodings from .pq import ConnStatus, ExecStatus, TransactionStatus, Format from .abc import ConnectionType, Params, PQGen, PQGenConn, Query, RV from .sql import Composable -from .rows import Row, RowFactory, tuple_row, TupleRow +from .rows import Row, RowFactory, AsyncRowFactory, tuple_row, TupleRow from ._enums import IsolationLevel from .compat import asynccontextmanager from .cursor import Cursor, AsyncCursor @@ -103,7 +103,11 @@ class BaseConnection(Generic[Row]): ConnStatus = pq.ConnStatus TransactionStatus = pq.TransactionStatus - def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Row]): + def __init__( + self, + pgconn: "PGconn", + row_factory: Union[RowFactory[Row], AsyncRowFactory[Row]], + ): self.pgconn = pgconn # TODO: document this self._row_factory = row_factory self._autocommit = False @@ -297,15 +301,6 @@ class BaseConnection(Generic[Row]): # implement the AdaptContext protocol return self - @property - def row_factory(self) -> RowFactory[Row]: - """Writable attribute to control how result rows are formed.""" - return self._row_factory - - @row_factory.setter - def row_factory(self, row_factory: RowFactory[Row]) -> None: - self._row_factory = row_factory - def fileno(self) -> int: """Return the file descriptor of the connection. @@ -619,6 +614,15 @@ class Connection(BaseConnection[Row]): self._closed = True self.pgconn.finish() + @property + def row_factory(self) -> RowFactory[Row]: + """Writable attribute to control how result rows are formed.""" + return cast(RowFactory[Row], self._row_factory) + + @row_factory.setter + def row_factory(self, row_factory: RowFactory[Row]) -> None: + self._row_factory = row_factory + @overload def cursor(self, *, binary: bool = False) -> Cursor[Row]: ... @@ -785,7 +789,7 @@ class AsyncConnection(BaseConnection[Row]): cursor_factory: Type[AsyncCursor[Row]] server_cursor_factory: Type[AsyncServerCursor[Row]] - def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Row]): + def __init__(self, pgconn: "PGconn", row_factory: AsyncRowFactory[Row]): super().__init__(pgconn, row_factory) self.lock = asyncio.Lock() self.cursor_factory = AsyncCursor @@ -798,7 +802,7 @@ class AsyncConnection(BaseConnection[Row]): conninfo: str = "", *, autocommit: bool = False, - row_factory: RowFactory[Row], + row_factory: AsyncRowFactory[Row], **kwargs: Union[None, int, str], ) -> "AsyncConnection[Row]": ... @@ -866,13 +870,22 @@ class AsyncConnection(BaseConnection[Row]): self._closed = True self.pgconn.finish() + @property + def row_factory(self) -> AsyncRowFactory[Row]: + """Writable attribute to control how result rows are formed.""" + return cast(AsyncRowFactory[Row], self._row_factory) + + @row_factory.setter + def row_factory(self, row_factory: AsyncRowFactory[Row]) -> None: + self._row_factory = row_factory + @overload def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]: ... @overload def cursor( - self, *, binary: bool = False, row_factory: RowFactory[CursorRow] + self, *, binary: bool = False, row_factory: AsyncRowFactory[CursorRow] ) -> AsyncCursor[CursorRow]: ... @@ -893,7 +906,7 @@ class AsyncConnection(BaseConnection[Row]): name: str, *, binary: bool = False, - row_factory: RowFactory[CursorRow], + row_factory: AsyncRowFactory[CursorRow], scrollable: Optional[bool] = None, withhold: bool = False, ) -> AsyncServerCursor[CursorRow]: @@ -904,7 +917,7 @@ class AsyncConnection(BaseConnection[Row]): name: str = "", *, binary: bool = False, - row_factory: Optional[RowFactory[Any]] = None, + row_factory: Optional[AsyncRowFactory[Any]] = None, scrollable: Optional[bool] = None, withhold: bool = False, ) -> Union[AsyncCursor[Any], AsyncServerCursor[Any]]: diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py index 860521368..739fba3b3 100644 --- a/psycopg/psycopg/cursor.py +++ b/psycopg/psycopg/cursor.py @@ -7,7 +7,7 @@ psycopg cursor objects import sys from types import TracebackType from typing import Any, AsyncIterator, Callable, Generic, Iterator, List -from typing import Optional, NoReturn, Sequence, Type, TYPE_CHECKING +from typing import Optional, NoReturn, Sequence, Type, TYPE_CHECKING, TypeVar from contextlib import contextmanager from . import pq @@ -18,7 +18,7 @@ from . import generators from .pq import ExecStatus, Format from .abc import ConnectionType, Query, Params, PQGen from .copy import Copy, AsyncCopy -from .rows import Row, RowFactory +from .rows import Row, RowMaker, RowFactory, AsyncRowFactory from .compat import asynccontextmanager from ._column import Column from ._cmodule import _psycopg @@ -53,17 +53,15 @@ class BaseCursor(Generic[ConnectionType, Row]): ExecStatus = pq.ExecStatus _tx: "Transformer" + _make_row: RowMaker[Row] def __init__( self, connection: ConnectionType, - *, - row_factory: RowFactory[Row], ): self._conn = connection self.format = Format.TEXT self._adapters = adapt.AdaptersMap(connection.adapters) - self._row_factory = row_factory self.arraysize = 1 self._closed = False self._last_query: Optional[Query] = None @@ -162,7 +160,7 @@ class BaseCursor(Generic[ConnectionType, Row]): if self._iresult < len(self._results): self.pgresult = self._results[self._iresult] self._tx.set_pgresult(self._results[self._iresult]) - self._make_row = self._row_factory(self) + self._make_row = self._make_row_maker() self._pos = 0 nrows = self.pgresult.command_tuples self._rowcount = nrows if nrows is not None else -1 @@ -170,16 +168,8 @@ class BaseCursor(Generic[ConnectionType, Row]): else: return None - @property - def row_factory(self) -> RowFactory[Row]: - """Writable attribute to control how result rows are formed.""" - return self._row_factory - - @row_factory.setter - def row_factory(self, row_factory: RowFactory[Row]) -> None: - self._row_factory = row_factory - if self.pgresult: - self._make_row = row_factory(self) + def _make_row_maker(self) -> RowMaker[Row]: + raise NotImplementedError # # Generators for the high level operations on the cursor @@ -276,7 +266,7 @@ class BaseCursor(Generic[ConnectionType, Row]): self.pgresult = res self._tx.set_pgresult(res, set_loaders=first) if first: - self._make_row = self._row_factory(self) + self._make_row = self._make_row_maker() return res elif res.status in (ExecStatus.TUPLES_OK, ExecStatus.COMMAND_OK): @@ -379,7 +369,7 @@ class BaseCursor(Generic[ConnectionType, Row]): self._results = list(results) self.pgresult = results[0] self._tx.set_pgresult(results[0]) - self._make_row = self._row_factory(self) + self._make_row = self._make_row_maker() nrows = self.pgresult.command_tuples if nrows is not None: if self._rowcount < 0: @@ -387,8 +377,6 @@ class BaseCursor(Generic[ConnectionType, Row]): else: self._rowcount += nrows - return - def _raise_from_results(self, results: Sequence["PGresult"]) -> NoReturn: statuses = {res.status for res in results} badstats = statuses.difference(self._status_ok) @@ -467,11 +455,20 @@ class BaseCursor(Generic[ConnectionType, Row]): AnyCursor = BaseCursor[Any, Row] +C = TypeVar("C", bound="BaseCursor[Any, Any]") + + class Cursor(BaseCursor["Connection[Any]", Row]): __module__ = "psycopg" __slots__ = () - def __enter__(self) -> "Cursor[Row]": + def __init__( + self, connection: "Connection[Any]", *, row_factory: RowFactory[Row] + ): + super().__init__(connection) + self._row_factory = row_factory + + def __enter__(self: C) -> C: return self def __exit__( @@ -488,13 +485,27 @@ class Cursor(BaseCursor["Connection[Any]", Row]): """ self._close() + @property + def row_factory(self) -> RowFactory[Row]: + """Writable attribute to control how result rows are formed.""" + return self._row_factory + + @row_factory.setter + def row_factory(self, row_factory: RowFactory[Row]) -> None: + self._row_factory = row_factory + if self.pgresult: + self._make_row = row_factory(self) + + def _make_row_maker(self) -> RowMaker[Row]: + return self._row_factory(self) + def execute( - self, + self: C, query: Query, params: Optional[Params] = None, *, prepare: Optional[bool] = None, - ) -> "Cursor[Row]": + ) -> C: """ Execute a query or command to the database. """ @@ -622,7 +633,16 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): __module__ = "psycopg" __slots__ = () - async def __aenter__(self) -> "AsyncCursor[Row]": + def __init__( + self, + connection: "AsyncConnection[Any]", + *, + row_factory: AsyncRowFactory[Row], + ): + super().__init__(connection) + self._row_factory = row_factory + + async def __aenter__(self: C) -> C: return self async def __aexit__( @@ -636,13 +656,26 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): async def close(self) -> None: self._close() + @property + def row_factory(self) -> AsyncRowFactory[Row]: + return self._row_factory + + @row_factory.setter + def row_factory(self, row_factory: AsyncRowFactory[Row]) -> None: + self._row_factory = row_factory + if self.pgresult: + self._make_row = row_factory(self) + + def _make_row_maker(self) -> RowMaker[Row]: + return self._row_factory(self) + async def execute( - self, + self: C, query: Query, params: Optional[Params] = None, *, prepare: Optional[bool] = None, - ) -> "AsyncCursor[Row]": + ) -> C: try: async with self._conn.lock: await self._conn.wait( diff --git a/psycopg/psycopg/rows.py b/psycopg/psycopg/rows.py index 1ffa9893d..5e4927e7f 100644 --- a/psycopg/psycopg/rows.py +++ b/psycopg/psycopg/rows.py @@ -14,7 +14,7 @@ from . import errors as e from .compat import Protocol if TYPE_CHECKING: - from .cursor import AnyCursor + from .cursor import AnyCursor, Cursor, AsyncCursor # Row factories @@ -52,7 +52,16 @@ class RowFactory(Protocol[Row]): use the values to create a dictionary for each record. """ - def __call__(self, __cursor: "AnyCursor[Row]") -> RowMaker[Row]: + def __call__(self, __cursor: "Cursor[Row]") -> RowMaker[Row]: + ... + + +class AsyncRowFactory(Protocol[Row]): + """ + Callable protocol taking an `~psycopg.AsyncCursor` and returning a `RowMaker`. + """ + + def __call__(self, __cursor: "AsyncCursor[Row]") -> RowMaker[Row]: ... diff --git a/psycopg/psycopg/server_cursor.py b/psycopg/psycopg/server_cursor.py index ec1859330..ae58261ed 100644 --- a/psycopg/psycopg/server_cursor.py +++ b/psycopg/psycopg/server_cursor.py @@ -5,19 +5,17 @@ psycopg server-side cursor objects. # Copyright (C) 2020-2021 The Psycopg Team import warnings -from types import TracebackType -from typing import AsyncIterator, Generic, List, Iterator, Optional -from typing import Sequence, Type, TYPE_CHECKING +from typing import Any, AsyncIterator, cast, Generic, List, Iterator, Optional +from typing import Sequence, TYPE_CHECKING from . import pq from . import sql from . import errors as e from .abc import ConnectionType, Query, Params, PQGen -from .rows import Row, RowFactory -from .cursor import BaseCursor, execute +from .rows import Row, RowFactory, AsyncRowFactory +from .cursor import C, BaseCursor, Cursor, AsyncCursor, execute if TYPE_CHECKING: - from typing import Any # noqa: F401 from .connection import BaseConnection # noqa: F401 from .connection import Connection, AsyncConnection # noqa: F401 @@ -175,7 +173,7 @@ class ServerCursorHelper(Generic[ConnectionType, Row]): return sql.SQL(" ").join(parts) -class ServerCursor(BaseCursor["Connection[Any]", Row]): +class ServerCursor(Cursor[Row]): __module__ = "psycopg" __slots__ = ("_helper", "itersize") @@ -204,17 +202,6 @@ class ServerCursor(BaseCursor["Connection[Any]", Row]): def __repr__(self) -> str: return self._helper._repr(self) - def __enter__(self) -> "ServerCursor[Row]": - return self - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> None: - self.close() - @property def name(self) -> str: """The name of the cursor.""" @@ -245,19 +232,23 @@ class ServerCursor(BaseCursor["Connection[Any]", Row]): if self.closed: return self._conn.wait(self._helper._close_gen(self)) - self._close() + super().close() def execute( - self, + self: C, query: Query, params: Optional[Params] = None, - ) -> "ServerCursor[Row]": + **kwargs: Any, + ) -> C: """ Open a cursor to execute a query to the database. """ - query = self._helper._make_declare_statement(self, query) + if kwargs: + raise TypeError(f"keyword not supported: {list(kwargs)[0]}") + helper = cast(ServerCursor[Row], self)._helper + query = helper._make_declare_statement(self, query) with self._conn.lock: - self._conn.wait(self._helper._declare_gen(self, query, params)) + self._conn.wait(helper._declare_gen(self, query, params)) return self def executemany(self, query: Query, params_seq: Sequence[Params]) -> None: @@ -311,7 +302,7 @@ class ServerCursor(BaseCursor["Connection[Any]", Row]): self._pos = value -class AsyncServerCursor(BaseCursor["AsyncConnection[Any]", Row]): +class AsyncServerCursor(AsyncCursor[Row]): __module__ = "psycopg" __slots__ = ("_helper", "itersize") @@ -320,7 +311,7 @@ class AsyncServerCursor(BaseCursor["AsyncConnection[Any]", Row]): connection: "AsyncConnection[Any]", name: str, *, - row_factory: RowFactory[Row], + row_factory: AsyncRowFactory[Row], scrollable: Optional[bool] = None, withhold: bool = False, ): @@ -340,17 +331,6 @@ class AsyncServerCursor(BaseCursor["AsyncConnection[Any]", Row]): def __repr__(self) -> str: return self._helper._repr(self) - async def __aenter__(self) -> "AsyncServerCursor[Row]": - return self - - async def __aexit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> None: - await self.close() - @property def name(self) -> str: return self._helper.name @@ -368,18 +348,20 @@ class AsyncServerCursor(BaseCursor["AsyncConnection[Any]", Row]): if self.closed: return await self._conn.wait(self._helper._close_gen(self)) - self._close() + await super().close() async def execute( - self, + self: C, query: Query, params: Optional[Params] = None, - ) -> "AsyncServerCursor[Row]": - query = self._helper._make_declare_statement(self, query) + **kwargs: Any, + ) -> C: + if kwargs: + raise TypeError(f"keyword not supported: {list(kwargs)[0]}") + helper = cast(AsyncServerCursor[Row], self)._helper + query = helper._make_declare_statement(self, query) async with self._conn.lock: - await self._conn.wait( - self._helper._declare_gen(self, query, params) - ) + await self._conn.wait(helper._declare_gen(self, query, params)) return self async def executemany( diff --git a/tests/typing_example.py b/tests/typing_example.py index b0d4c3e61..aeb7a7a6b 100644 --- a/tests/typing_example.py +++ b/tests/typing_example.py @@ -3,12 +3,15 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Callable, Optional, Sequence, Tuple +from typing import Any, Callable, Optional, Sequence, Tuple, Union -from psycopg import AnyCursor, Connection, Cursor, ServerCursor, connect +from psycopg import Connection, Cursor, ServerCursor, connect +from psycopg import AsyncConnection, AsyncCursor, AsyncServerCursor -def int_row_factory(cursor: AnyCursor[int]) -> Callable[[Sequence[int]], int]: +def int_row_factory( + cursor: Union[Cursor[int], AsyncCursor[int]] +) -> Callable[[Sequence[int]], int]: return lambda values: values[0] if values else 42 @@ -19,7 +22,7 @@ class Person: @classmethod def row_factory( - cls, cursor: AnyCursor[Person] + cls, cursor: Union[Cursor[Person], AsyncCursor[Person]] ) -> Callable[[Sequence[str]], Person]: def mkrow(values: Sequence[str]) -> Person: name, address = values @@ -53,6 +56,31 @@ def check_row_factory_cursor() -> None: persons[0].address +async def async_check_row_factory_cursor() -> None: + """Type-check connection.cursor(..., row_factory=) case.""" + conn = await AsyncConnection.connect() + + cur1: AsyncCursor[Any] + cur1 = conn.cursor() + r1: Optional[Any] + r1 = await cur1.fetchone() + r1 is not None + + cur2: AsyncCursor[int] + r2: Optional[int] + async with conn.cursor(row_factory=int_row_factory) as cur2: + await cur2.execute("select 1") + r2 = await cur2.fetchone() + r2 and r2 > 0 + + cur3: AsyncServerCursor[Person] + persons: Sequence[Person] + async with conn.cursor(name="s", row_factory=Person.row_factory) as cur3: + await cur3.execute("select * from persons where name like 'al%'") + persons = await cur3.fetchall() + persons[0].address + + def check_row_factory_connection() -> None: """Type-check connect(..., row_factory=) or Connection.row_factory cases. @@ -85,3 +113,37 @@ def check_row_factory_connection() -> None: cur3.execute("select 42") r3 = cur3.fetchone() r3 and len(r3) + + +async def async_check_row_factory_connection() -> None: + """Type-check connect(..., row_factory=) or + Connection.row_factory cases. + """ + conn1: AsyncConnection[int] + cur1: AsyncCursor[int] + r1: Optional[int] + conn1 = await AsyncConnection.connect(row_factory=int_row_factory) + cur1 = await conn1.execute("select 1") + r1 = await cur1.fetchone() + r1 != 0 + async with conn1.cursor() as cur1: + await cur1.execute("select 2") + + conn2: AsyncConnection[Person] + cur2: AsyncCursor[Person] + r2: Optional[Person] + conn2 = await AsyncConnection.connect(row_factory=Person.row_factory) + cur2 = await conn2.execute("select * from persons") + r2 = await cur2.fetchone() + r2 and r2.name + async with conn2.cursor() as cur2: + await cur2.execute("select 2") + + cur3: AsyncCursor[Tuple[Any, ...]] + r3: Optional[Tuple[Any, ...]] + conn3 = await AsyncConnection.connect() + cur3 = await conn3.execute("select 3") + async with conn3.cursor() as cur3: + await cur3.execute("select 42") + r3 = await cur3.fetchone() + r3 and len(r3) -- 2.47.3