]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add AsyncRowFactory class
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 1 Aug 2021 12:55:48 +0000 (14:55 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 1 Aug 2021 20:36:23 +0000 (22:36 +0200)
The extra class allow clients to define their RowFactory just taking a
Cursor or an AsyncCursor, which is easier if the program handles only
one type of connection (sync or async).

Using a Generic, the server-side cursors are now subclassed from  the
respective client-side cursor (sync and async), which allows to drop a
bit of implementation duplication.

docs/advanced/rows.rst
psycopg/psycopg/connection.py
psycopg/psycopg/cursor.py
psycopg/psycopg/rows.py
psycopg/psycopg/server_cursor.py
tests/typing_example.py

index 2d05e0257df8446c8f2d1aa18049d477a4dffc73..3b4c500185508daf58ee79597122061f55bb8a09 100644 (file)
@@ -25,13 +25,21 @@ callable (formally the `~psycopg.rows.RowMaker` protocol) accepting a
 
 .. autoclass:: psycopg.rows.RowFactory()
 
-   .. method:: __call__(cursor: AnyCursor[Row]) -> RowMaker[Row]
+   .. method:: __call__(cursor: Cursor[Row]) -> RowMaker[Row]
 
         Inspect the result on a cursor and return a `RowMaker` to convert rows.
 
-        `!AnyCursor` may be either a `~psycopg.Cursor` or an
-        `~psycopg.AsyncCursor`.
+.. autoclass:: psycopg.rows.AsyncRowFactory()
 
+   .. method:: __call__(cursor: AsyncCursor[Row]) -> RowMaker[Row]
+
+        Inspect the result on a cursor and return a `RowMaker` to convert rows.
+
+Note that it's easy to implement an object implementing both `!RowFactory` and
+`!AsyncRowFactory`: usually, everything you need to implement a row factory is
+to access `~Cursor.description`, which is provided by both the cursor flavours.
+The `psycopg` module also exposes a class `AnyCursor` which you may use if you
+want to use the same row factory for both sync and async cursors.
 
 `~RowFactory` objects can be implemented as a class, for instance:
 
index ee1d280d5bd73a44a21e150cd869ce2bdb72916f..1a32c0561a866793e9ed9ea28df9578e99bd64d7 100644 (file)
@@ -9,7 +9,7 @@ import logging
 import warnings
 import threading
 from types import TracebackType
-from typing import Any, AsyncIterator, Callable, Generic, Iterator, List
+from typing import Any, AsyncIterator, Callable, cast, Generic, Iterator, List
 from typing import NamedTuple, Optional, Type, TypeVar, Union
 from typing import overload, TYPE_CHECKING
 from weakref import ref, ReferenceType
@@ -25,7 +25,7 @@ from . import encodings
 from .pq import ConnStatus, ExecStatus, TransactionStatus, Format
 from .abc import ConnectionType, Params, PQGen, PQGenConn, Query, RV
 from .sql import Composable
-from .rows import Row, RowFactory, tuple_row, TupleRow
+from .rows import Row, RowFactory, AsyncRowFactory, tuple_row, TupleRow
 from ._enums import IsolationLevel
 from .compat import asynccontextmanager
 from .cursor import Cursor, AsyncCursor
@@ -103,7 +103,11 @@ class BaseConnection(Generic[Row]):
     ConnStatus = pq.ConnStatus
     TransactionStatus = pq.TransactionStatus
 
-    def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Row]):
+    def __init__(
+        self,
+        pgconn: "PGconn",
+        row_factory: Union[RowFactory[Row], AsyncRowFactory[Row]],
+    ):
         self.pgconn = pgconn  # TODO: document this
         self._row_factory = row_factory
         self._autocommit = False
@@ -297,15 +301,6 @@ class BaseConnection(Generic[Row]):
         # implement the AdaptContext protocol
         return self
 
-    @property
-    def row_factory(self) -> RowFactory[Row]:
-        """Writable attribute to control how result rows are formed."""
-        return self._row_factory
-
-    @row_factory.setter
-    def row_factory(self, row_factory: RowFactory[Row]) -> None:
-        self._row_factory = row_factory
-
     def fileno(self) -> int:
         """Return the file descriptor of the connection.
 
@@ -619,6 +614,15 @@ class Connection(BaseConnection[Row]):
         self._closed = True
         self.pgconn.finish()
 
+    @property
+    def row_factory(self) -> RowFactory[Row]:
+        """Writable attribute to control how result rows are formed."""
+        return cast(RowFactory[Row], self._row_factory)
+
+    @row_factory.setter
+    def row_factory(self, row_factory: RowFactory[Row]) -> None:
+        self._row_factory = row_factory
+
     @overload
     def cursor(self, *, binary: bool = False) -> Cursor[Row]:
         ...
@@ -785,7 +789,7 @@ class AsyncConnection(BaseConnection[Row]):
     cursor_factory: Type[AsyncCursor[Row]]
     server_cursor_factory: Type[AsyncServerCursor[Row]]
 
-    def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Row]):
+    def __init__(self, pgconn: "PGconn", row_factory: AsyncRowFactory[Row]):
         super().__init__(pgconn, row_factory)
         self.lock = asyncio.Lock()
         self.cursor_factory = AsyncCursor
@@ -798,7 +802,7 @@ class AsyncConnection(BaseConnection[Row]):
         conninfo: str = "",
         *,
         autocommit: bool = False,
-        row_factory: RowFactory[Row],
+        row_factory: AsyncRowFactory[Row],
         **kwargs: Union[None, int, str],
     ) -> "AsyncConnection[Row]":
         ...
@@ -866,13 +870,22 @@ class AsyncConnection(BaseConnection[Row]):
         self._closed = True
         self.pgconn.finish()
 
+    @property
+    def row_factory(self) -> AsyncRowFactory[Row]:
+        """Writable attribute to control how result rows are formed."""
+        return cast(AsyncRowFactory[Row], self._row_factory)
+
+    @row_factory.setter
+    def row_factory(self, row_factory: AsyncRowFactory[Row]) -> None:
+        self._row_factory = row_factory
+
     @overload
     def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]:
         ...
 
     @overload
     def cursor(
-        self, *, binary: bool = False, row_factory: RowFactory[CursorRow]
+        self, *, binary: bool = False, row_factory: AsyncRowFactory[CursorRow]
     ) -> AsyncCursor[CursorRow]:
         ...
 
@@ -893,7 +906,7 @@ class AsyncConnection(BaseConnection[Row]):
         name: str,
         *,
         binary: bool = False,
-        row_factory: RowFactory[CursorRow],
+        row_factory: AsyncRowFactory[CursorRow],
         scrollable: Optional[bool] = None,
         withhold: bool = False,
     ) -> AsyncServerCursor[CursorRow]:
@@ -904,7 +917,7 @@ class AsyncConnection(BaseConnection[Row]):
         name: str = "",
         *,
         binary: bool = False,
-        row_factory: Optional[RowFactory[Any]] = None,
+        row_factory: Optional[AsyncRowFactory[Any]] = None,
         scrollable: Optional[bool] = None,
         withhold: bool = False,
     ) -> Union[AsyncCursor[Any], AsyncServerCursor[Any]]:
index 860521368bd63df62fc01efe236cefaea1ae2771..739fba3b32bfd605b760fb5bd829b1dbe20e313f 100644 (file)
@@ -7,7 +7,7 @@ psycopg cursor objects
 import sys
 from types import TracebackType
 from typing import Any, AsyncIterator, Callable, Generic, Iterator, List
-from typing import Optional, NoReturn, Sequence, Type, TYPE_CHECKING
+from typing import Optional, NoReturn, Sequence, Type, TYPE_CHECKING, TypeVar
 from contextlib import contextmanager
 
 from . import pq
@@ -18,7 +18,7 @@ from . import generators
 from .pq import ExecStatus, Format
 from .abc import ConnectionType, Query, Params, PQGen
 from .copy import Copy, AsyncCopy
-from .rows import Row, RowFactory
+from .rows import Row, RowMaker, RowFactory, AsyncRowFactory
 from .compat import asynccontextmanager
 from ._column import Column
 from ._cmodule import _psycopg
@@ -53,17 +53,15 @@ class BaseCursor(Generic[ConnectionType, Row]):
     ExecStatus = pq.ExecStatus
 
     _tx: "Transformer"
+    _make_row: RowMaker[Row]
 
     def __init__(
         self,
         connection: ConnectionType,
-        *,
-        row_factory: RowFactory[Row],
     ):
         self._conn = connection
         self.format = Format.TEXT
         self._adapters = adapt.AdaptersMap(connection.adapters)
-        self._row_factory = row_factory
         self.arraysize = 1
         self._closed = False
         self._last_query: Optional[Query] = None
@@ -162,7 +160,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
         if self._iresult < len(self._results):
             self.pgresult = self._results[self._iresult]
             self._tx.set_pgresult(self._results[self._iresult])
-            self._make_row = self._row_factory(self)
+            self._make_row = self._make_row_maker()
             self._pos = 0
             nrows = self.pgresult.command_tuples
             self._rowcount = nrows if nrows is not None else -1
@@ -170,16 +168,8 @@ class BaseCursor(Generic[ConnectionType, Row]):
         else:
             return None
 
-    @property
-    def row_factory(self) -> RowFactory[Row]:
-        """Writable attribute to control how result rows are formed."""
-        return self._row_factory
-
-    @row_factory.setter
-    def row_factory(self, row_factory: RowFactory[Row]) -> None:
-        self._row_factory = row_factory
-        if self.pgresult:
-            self._make_row = row_factory(self)
+    def _make_row_maker(self) -> RowMaker[Row]:
+        raise NotImplementedError
 
     #
     # Generators for the high level operations on the cursor
@@ -276,7 +266,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
             self.pgresult = res
             self._tx.set_pgresult(res, set_loaders=first)
             if first:
-                self._make_row = self._row_factory(self)
+                self._make_row = self._make_row_maker()
             return res
 
         elif res.status in (ExecStatus.TUPLES_OK, ExecStatus.COMMAND_OK):
@@ -379,7 +369,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
         self._results = list(results)
         self.pgresult = results[0]
         self._tx.set_pgresult(results[0])
-        self._make_row = self._row_factory(self)
+        self._make_row = self._make_row_maker()
         nrows = self.pgresult.command_tuples
         if nrows is not None:
             if self._rowcount < 0:
@@ -387,8 +377,6 @@ class BaseCursor(Generic[ConnectionType, Row]):
             else:
                 self._rowcount += nrows
 
-        return
-
     def _raise_from_results(self, results: Sequence["PGresult"]) -> NoReturn:
         statuses = {res.status for res in results}
         badstats = statuses.difference(self._status_ok)
@@ -467,11 +455,20 @@ class BaseCursor(Generic[ConnectionType, Row]):
 AnyCursor = BaseCursor[Any, Row]
 
 
+C = TypeVar("C", bound="BaseCursor[Any, Any]")
+
+
 class Cursor(BaseCursor["Connection[Any]", Row]):
     __module__ = "psycopg"
     __slots__ = ()
 
-    def __enter__(self) -> "Cursor[Row]":
+    def __init__(
+        self, connection: "Connection[Any]", *, row_factory: RowFactory[Row]
+    ):
+        super().__init__(connection)
+        self._row_factory = row_factory
+
+    def __enter__(self: C) -> C:
         return self
 
     def __exit__(
@@ -488,13 +485,27 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         """
         self._close()
 
+    @property
+    def row_factory(self) -> RowFactory[Row]:
+        """Writable attribute to control how result rows are formed."""
+        return self._row_factory
+
+    @row_factory.setter
+    def row_factory(self, row_factory: RowFactory[Row]) -> None:
+        self._row_factory = row_factory
+        if self.pgresult:
+            self._make_row = row_factory(self)
+
+    def _make_row_maker(self) -> RowMaker[Row]:
+        return self._row_factory(self)
+
     def execute(
-        self,
+        self: C,
         query: Query,
         params: Optional[Params] = None,
         *,
         prepare: Optional[bool] = None,
-    ) -> "Cursor[Row]":
+    ) -> C:
         """
         Execute a query or command to the database.
         """
@@ -622,7 +633,16 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
     __module__ = "psycopg"
     __slots__ = ()
 
-    async def __aenter__(self) -> "AsyncCursor[Row]":
+    def __init__(
+        self,
+        connection: "AsyncConnection[Any]",
+        *,
+        row_factory: AsyncRowFactory[Row],
+    ):
+        super().__init__(connection)
+        self._row_factory = row_factory
+
+    async def __aenter__(self: C) -> C:
         return self
 
     async def __aexit__(
@@ -636,13 +656,26 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
     async def close(self) -> None:
         self._close()
 
+    @property
+    def row_factory(self) -> AsyncRowFactory[Row]:
+        return self._row_factory
+
+    @row_factory.setter
+    def row_factory(self, row_factory: AsyncRowFactory[Row]) -> None:
+        self._row_factory = row_factory
+        if self.pgresult:
+            self._make_row = row_factory(self)
+
+    def _make_row_maker(self) -> RowMaker[Row]:
+        return self._row_factory(self)
+
     async def execute(
-        self,
+        self: C,
         query: Query,
         params: Optional[Params] = None,
         *,
         prepare: Optional[bool] = None,
-    ) -> "AsyncCursor[Row]":
+    ) -> C:
         try:
             async with self._conn.lock:
                 await self._conn.wait(
index 1ffa9893d12191ee8ff02b6137a54a58d02ac35f..5e4927e7f5dab593b26ce60edf759e974221e77c 100644 (file)
@@ -14,7 +14,7 @@ from . import errors as e
 from .compat import Protocol
 
 if TYPE_CHECKING:
-    from .cursor import AnyCursor
+    from .cursor import AnyCursor, Cursor, AsyncCursor
 
 # Row factories
 
@@ -52,7 +52,16 @@ class RowFactory(Protocol[Row]):
     use the values to create a dictionary for each record.
     """
 
-    def __call__(self, __cursor: "AnyCursor[Row]") -> RowMaker[Row]:
+    def __call__(self, __cursor: "Cursor[Row]") -> RowMaker[Row]:
+        ...
+
+
+class AsyncRowFactory(Protocol[Row]):
+    """
+    Callable protocol taking an `~psycopg.AsyncCursor` and returning a `RowMaker`.
+    """
+
+    def __call__(self, __cursor: "AsyncCursor[Row]") -> RowMaker[Row]:
         ...
 
 
index ec185933014d5538b1feb6e62101349c24d5ec57..ae58261edb7acc161c3f8e39932f2d500ece1d1a 100644 (file)
@@ -5,19 +5,17 @@ psycopg server-side cursor objects.
 # Copyright (C) 2020-2021 The Psycopg Team
 
 import warnings
-from types import TracebackType
-from typing import AsyncIterator, Generic, List, Iterator, Optional
-from typing import Sequence, Type, TYPE_CHECKING
+from typing import Any, AsyncIterator, cast, Generic, List, Iterator, Optional
+from typing import Sequence, TYPE_CHECKING
 
 from . import pq
 from . import sql
 from . import errors as e
 from .abc import ConnectionType, Query, Params, PQGen
-from .rows import Row, RowFactory
-from .cursor import BaseCursor, execute
+from .rows import Row, RowFactory, AsyncRowFactory
+from .cursor import C, BaseCursor, Cursor, AsyncCursor, execute
 
 if TYPE_CHECKING:
-    from typing import Any  # noqa: F401
     from .connection import BaseConnection  # noqa: F401
     from .connection import Connection, AsyncConnection  # noqa: F401
 
@@ -175,7 +173,7 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
         return sql.SQL(" ").join(parts)
 
 
-class ServerCursor(BaseCursor["Connection[Any]", Row]):
+class ServerCursor(Cursor[Row]):
     __module__ = "psycopg"
     __slots__ = ("_helper", "itersize")
 
@@ -204,17 +202,6 @@ class ServerCursor(BaseCursor["Connection[Any]", Row]):
     def __repr__(self) -> str:
         return self._helper._repr(self)
 
-    def __enter__(self) -> "ServerCursor[Row]":
-        return self
-
-    def __exit__(
-        self,
-        exc_type: Optional[Type[BaseException]],
-        exc_val: Optional[BaseException],
-        exc_tb: Optional[TracebackType],
-    ) -> None:
-        self.close()
-
     @property
     def name(self) -> str:
         """The name of the cursor."""
@@ -245,19 +232,23 @@ class ServerCursor(BaseCursor["Connection[Any]", Row]):
             if self.closed:
                 return
             self._conn.wait(self._helper._close_gen(self))
-            self._close()
+            super().close()
 
     def execute(
-        self,
+        self: C,
         query: Query,
         params: Optional[Params] = None,
-    ) -> "ServerCursor[Row]":
+        **kwargs: Any,
+    ) -> C:
         """
         Open a cursor to execute a query to the database.
         """
-        query = self._helper._make_declare_statement(self, query)
+        if kwargs:
+            raise TypeError(f"keyword not supported: {list(kwargs)[0]}")
+        helper = cast(ServerCursor[Row], self)._helper
+        query = helper._make_declare_statement(self, query)
         with self._conn.lock:
-            self._conn.wait(self._helper._declare_gen(self, query, params))
+            self._conn.wait(helper._declare_gen(self, query, params))
         return self
 
     def executemany(self, query: Query, params_seq: Sequence[Params]) -> None:
@@ -311,7 +302,7 @@ class ServerCursor(BaseCursor["Connection[Any]", Row]):
             self._pos = value
 
 
-class AsyncServerCursor(BaseCursor["AsyncConnection[Any]", Row]):
+class AsyncServerCursor(AsyncCursor[Row]):
     __module__ = "psycopg"
     __slots__ = ("_helper", "itersize")
 
@@ -320,7 +311,7 @@ class AsyncServerCursor(BaseCursor["AsyncConnection[Any]", Row]):
         connection: "AsyncConnection[Any]",
         name: str,
         *,
-        row_factory: RowFactory[Row],
+        row_factory: AsyncRowFactory[Row],
         scrollable: Optional[bool] = None,
         withhold: bool = False,
     ):
@@ -340,17 +331,6 @@ class AsyncServerCursor(BaseCursor["AsyncConnection[Any]", Row]):
     def __repr__(self) -> str:
         return self._helper._repr(self)
 
-    async def __aenter__(self) -> "AsyncServerCursor[Row]":
-        return self
-
-    async def __aexit__(
-        self,
-        exc_type: Optional[Type[BaseException]],
-        exc_val: Optional[BaseException],
-        exc_tb: Optional[TracebackType],
-    ) -> None:
-        await self.close()
-
     @property
     def name(self) -> str:
         return self._helper.name
@@ -368,18 +348,20 @@ class AsyncServerCursor(BaseCursor["AsyncConnection[Any]", Row]):
             if self.closed:
                 return
             await self._conn.wait(self._helper._close_gen(self))
-            self._close()
+            await super().close()
 
     async def execute(
-        self,
+        self: C,
         query: Query,
         params: Optional[Params] = None,
-    ) -> "AsyncServerCursor[Row]":
-        query = self._helper._make_declare_statement(self, query)
+        **kwargs: Any,
+    ) -> C:
+        if kwargs:
+            raise TypeError(f"keyword not supported: {list(kwargs)[0]}")
+        helper = cast(AsyncServerCursor[Row], self)._helper
+        query = helper._make_declare_statement(self, query)
         async with self._conn.lock:
-            await self._conn.wait(
-                self._helper._declare_gen(self, query, params)
-            )
+            await self._conn.wait(helper._declare_gen(self, query, params))
         return self
 
     async def executemany(
index b0d4c3e611fc48397eb22a93df73e86771cdb30e..aeb7a7a6bdc0873ac2b5e236e67b9fd7889f3960 100644 (file)
@@ -3,12 +3,15 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Any, Callable, Optional, Sequence, Tuple
+from typing import Any, Callable, Optional, Sequence, Tuple, Union
 
-from psycopg import AnyCursor, Connection, Cursor, ServerCursor, connect
+from psycopg import Connection, Cursor, ServerCursor, connect
+from psycopg import AsyncConnection, AsyncCursor, AsyncServerCursor
 
 
-def int_row_factory(cursor: AnyCursor[int]) -> Callable[[Sequence[int]], int]:
+def int_row_factory(
+    cursor: Union[Cursor[int], AsyncCursor[int]]
+) -> Callable[[Sequence[int]], int]:
     return lambda values: values[0] if values else 42
 
 
@@ -19,7 +22,7 @@ class Person:
 
     @classmethod
     def row_factory(
-        cls, cursor: AnyCursor[Person]
+        cls, cursor: Union[Cursor[Person], AsyncCursor[Person]]
     ) -> Callable[[Sequence[str]], Person]:
         def mkrow(values: Sequence[str]) -> Person:
             name, address = values
@@ -53,6 +56,31 @@ def check_row_factory_cursor() -> None:
         persons[0].address
 
 
+async def async_check_row_factory_cursor() -> None:
+    """Type-check connection.cursor(..., row_factory=<MyRowFactory>) case."""
+    conn = await AsyncConnection.connect()
+
+    cur1: AsyncCursor[Any]
+    cur1 = conn.cursor()
+    r1: Optional[Any]
+    r1 = await cur1.fetchone()
+    r1 is not None
+
+    cur2: AsyncCursor[int]
+    r2: Optional[int]
+    async with conn.cursor(row_factory=int_row_factory) as cur2:
+        await cur2.execute("select 1")
+        r2 = await cur2.fetchone()
+        r2 and r2 > 0
+
+    cur3: AsyncServerCursor[Person]
+    persons: Sequence[Person]
+    async with conn.cursor(name="s", row_factory=Person.row_factory) as cur3:
+        await cur3.execute("select * from persons where name like 'al%'")
+        persons = await cur3.fetchall()
+        persons[0].address
+
+
 def check_row_factory_connection() -> None:
     """Type-check connect(..., row_factory=<MyRowFactory>) or
     Connection.row_factory cases.
@@ -85,3 +113,37 @@ def check_row_factory_connection() -> None:
         cur3.execute("select 42")
         r3 = cur3.fetchone()
         r3 and len(r3)
+
+
+async def async_check_row_factory_connection() -> None:
+    """Type-check connect(..., row_factory=<MyRowFactory>) or
+    Connection.row_factory cases.
+    """
+    conn1: AsyncConnection[int]
+    cur1: AsyncCursor[int]
+    r1: Optional[int]
+    conn1 = await AsyncConnection.connect(row_factory=int_row_factory)
+    cur1 = await conn1.execute("select 1")
+    r1 = await cur1.fetchone()
+    r1 != 0
+    async with conn1.cursor() as cur1:
+        await cur1.execute("select 2")
+
+    conn2: AsyncConnection[Person]
+    cur2: AsyncCursor[Person]
+    r2: Optional[Person]
+    conn2 = await AsyncConnection.connect(row_factory=Person.row_factory)
+    cur2 = await conn2.execute("select * from persons")
+    r2 = await cur2.fetchone()
+    r2 and r2.name
+    async with conn2.cursor() as cur2:
+        await cur2.execute("select 2")
+
+    cur3: AsyncCursor[Tuple[Any, ...]]
+    r3: Optional[Tuple[Any, ...]]
+    conn3 = await AsyncConnection.connect()
+    cur3 = await conn3.execute("select 3")
+    async with conn3.cursor() as cur3:
+        await cur3.execute("select 42")
+        r3 = await cur3.fetchone()
+        r3 and len(r3)