From: Daniele Varrazzo Date: Wed, 4 May 2022 18:09:31 +0000 (+0200) Subject: feat: make Cursor constructor public X-Git-Tag: 3.1~108 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c04da6b6237738392ecb9f19d926f49911d2e75c;p=thirdparty%2Fpsycopg.git feat: make Cursor constructor public This allows a more natural creation of cursors subclasses, without the need of tweaking the connection's cursor_factory and allowing to pass arbitrary init arguments to the cursor. --- diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index 7a169207d..ff4d1ea7e 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -643,11 +643,15 @@ class Connection(BaseConnection[Row]): _pipeline: Optional[Pipeline] - def __init__(self, pgconn: "PGconn", row_factory: Optional[RowFactory[Row]] = None): + def __init__( + self, + pgconn: "PGconn", + row_factory: RowFactory[Row] = cast(RowFactory[Row], tuple_row), + ): super().__init__(pgconn) - self.row_factory = row_factory or cast(RowFactory[Row], tuple_row) + self.row_factory = row_factory self.lock = threading.Lock() - self.cursor_factory = Cursor + self.cursor_factory = cast("Type[Cursor[Row]]", Cursor) self.server_cursor_factory = ServerCursor @overload diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index ef6c2bbd7..2cbaa745d 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -51,12 +51,12 @@ class AsyncConnection(BaseConnection[Row]): def __init__( self, pgconn: "PGconn", - row_factory: Optional[AsyncRowFactory[Row]] = None, + row_factory: AsyncRowFactory[Row] = cast(AsyncRowFactory[Row], tuple_row), ): super().__init__(pgconn) - self.row_factory = row_factory or cast(AsyncRowFactory[Row], tuple_row) + self.row_factory = row_factory self.lock = asyncio.Lock() - self.cursor_factory = AsyncCursor + self.cursor_factory = cast("Type[AsyncCursor[Row]]", AsyncCursor) self.server_cursor_factory = AsyncServerCursor @overload diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py index 459351581..22a10e9e4 100644 --- a/psycopg/psycopg/cursor.py +++ b/psycopg/psycopg/cursor.py @@ -7,7 +7,8 @@ psycopg cursor objects from functools import partial from types import TracebackType from typing import Any, Generic, Iterable, Iterator, List -from typing import Optional, NoReturn, Sequence, Type, TypeVar, TYPE_CHECKING +from typing import Optional, NoReturn, Sequence, Type, TypeVar +from typing import overload, TYPE_CHECKING from contextlib import contextmanager from . import pq @@ -637,9 +638,27 @@ class Cursor(BaseCursor["Connection[Any]", Row]): __module__ = "psycopg" __slots__ = () - def __init__(self, connection: "Connection[Any]", *, row_factory: RowFactory[Row]): + @overload + def __init__(self: "Cursor[Row]", connection: "Connection[Row]"): + ... + + @overload + def __init__( + self: "Cursor[Row]", + connection: "Connection[Any]", + *, + row_factory: RowFactory[Row], + ): + ... + + def __init__( + self, + connection: "Connection[Any]", + *, + row_factory: Optional[RowFactory[Row]] = None, + ): super().__init__(connection) - self._row_factory = row_factory + self._row_factory = row_factory or connection.row_factory def __enter__(self: _C) -> _C: return self diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py index 8c4ac4d9e..d6c7dbbba 100644 --- a/psycopg/psycopg/cursor_async.py +++ b/psycopg/psycopg/cursor_async.py @@ -6,7 +6,7 @@ psycopg async cursor objects from types import TracebackType from typing import Any, AsyncIterator, Iterable, List -from typing import Optional, Type, TypeVar, TYPE_CHECKING +from typing import Optional, Type, TypeVar, TYPE_CHECKING, overload from contextlib import asynccontextmanager from . import errors as e @@ -26,14 +26,27 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): __module__ = "psycopg" __slots__ = () + @overload + def __init__(self: "AsyncCursor[Row]", connection: "AsyncConnection[Row]"): + ... + + @overload def __init__( - self, + self: "AsyncCursor[Row]", connection: "AsyncConnection[Any]", *, row_factory: AsyncRowFactory[Row], + ): + ... + + def __init__( + self, + connection: "AsyncConnection[Any]", + *, + row_factory: Optional[AsyncRowFactory[Row]] = None, ): super().__init__(connection) - self._row_factory = row_factory + self._row_factory = row_factory or connection.row_factory async def __aenter__(self: _C) -> _C: return self diff --git a/psycopg/psycopg/server_cursor.py b/psycopg/psycopg/server_cursor.py index 9561297f1..d7913090a 100644 --- a/psycopg/psycopg/server_cursor.py +++ b/psycopg/psycopg/server_cursor.py @@ -5,7 +5,7 @@ psycopg server-side cursor objects. # Copyright (C) 2020 The Psycopg Team from typing import Any, AsyncIterator, Generic, List, Iterable, Iterator -from typing import Optional, TypeVar, TYPE_CHECKING +from typing import Optional, TypeVar, TYPE_CHECKING, overload from warnings import warn from . import pq @@ -182,8 +182,20 @@ class ServerCursor(Cursor[Row]): __module__ = "psycopg" __slots__ = ("_helper", "itersize") + @overload def __init__( - self, + self: "ServerCursor[Row]", + connection: "Connection[Row]", + name: str, + *, + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + ... + + @overload + def __init__( + self: "ServerCursor[Row]", connection: "Connection[Any]", name: str, *, @@ -191,7 +203,18 @@ class ServerCursor(Cursor[Row]): scrollable: Optional[bool] = None, withhold: bool = False, ): - super().__init__(connection, row_factory=row_factory) + ... + + def __init__( + self, + connection: "Connection[Any]", + name: str, + *, + row_factory: Optional[RowFactory[Row]] = None, + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + super().__init__(connection, row_factory=row_factory or connection.row_factory) self._helper: ServerCursorHelper["Connection[Any]", Row] self._helper = ServerCursorHelper(name, scrollable, withhold) self.itersize: int = DEFAULT_ITERSIZE @@ -323,8 +346,20 @@ class AsyncServerCursor(AsyncCursor[Row]): __module__ = "psycopg" __slots__ = ("_helper", "itersize") + @overload def __init__( - self, + self: "AsyncServerCursor[Row]", + connection: "AsyncConnection[Row]", + name: str, + *, + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + ... + + @overload + def __init__( + self: "AsyncServerCursor[Row]", connection: "AsyncConnection[Any]", name: str, *, @@ -332,7 +367,18 @@ class AsyncServerCursor(AsyncCursor[Row]): scrollable: Optional[bool] = None, withhold: bool = False, ): - super().__init__(connection, row_factory=row_factory) + ... + + def __init__( + self, + connection: "AsyncConnection[Any]", + name: str, + *, + row_factory: Optional[AsyncRowFactory[Row]] = None, + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + super().__init__(connection, row_factory=row_factory or connection.row_factory) self._helper: ServerCursorHelper["AsyncConnection[Any]", Row] self._helper = ServerCursorHelper(name, scrollable, withhold) self.itersize: int = DEFAULT_ITERSIZE diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 1d1b6275d..a52b3fc37 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -15,6 +15,23 @@ from psycopg.rows import RowMaker from .utils import gc_collect +def test_init(conn): + cur = psycopg.Cursor(conn) + cur.execute("select 1") + assert cur.fetchone() == (1,) + + conn.row_factory = rows.dict_row + cur = psycopg.Cursor(conn) + cur.execute("select 1 as a") + assert cur.fetchone() == {"a": 1} + + +def test_init_factory(conn): + cur = psycopg.Cursor(conn, row_factory=rows.dict_row) + cur.execute("select 1 as a") + assert cur.fetchone() == {"a": 1} + + def test_close(conn): cur = conn.cursor() assert not cur.closed diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py index 78aeafd5a..ea5bf7b13 100644 --- a/tests/test_cursor_async.py +++ b/tests/test_cursor_async.py @@ -16,6 +16,23 @@ execmany = execmany # avoid F811 underneath pytestmark = pytest.mark.asyncio +async def test_init(aconn): + cur = psycopg.AsyncCursor(aconn) + await cur.execute("select 1") + assert (await cur.fetchone()) == (1,) + + aconn.row_factory = rows.dict_row + cur = psycopg.AsyncCursor(aconn) + await cur.execute("select 1 as a") + assert (await cur.fetchone()) == {"a": 1} + + +async def test_init_factory(aconn): + cur = psycopg.AsyncCursor(aconn, row_factory=rows.dict_row) + await cur.execute("select 1 as a") + assert (await cur.fetchone()) == {"a": 1} + + async def test_close(aconn): cur = aconn.cursor() assert not cur.closed diff --git a/tests/test_server_cursor.py b/tests/test_server_cursor.py index 9b77fa852..78b51ce00 100644 --- a/tests/test_server_cursor.py +++ b/tests/test_server_cursor.py @@ -1,8 +1,35 @@ import pytest -from psycopg import errors as e +import psycopg +from psycopg import rows, errors as e from psycopg.pq import Format -from psycopg.rows import dict_row + + +def test_init_row_factory(conn): + with psycopg.ServerCursor(conn, "foo") as cur: + assert cur.name == "foo" + assert cur.connection is conn + assert cur.row_factory is conn.row_factory + + conn.row_factory = rows.dict_row + + with psycopg.ServerCursor(conn, "bar") as cur: + assert cur.name == "bar" + assert cur.row_factory is rows.dict_row # type: ignore + + with psycopg.ServerCursor(conn, "baz", row_factory=rows.namedtuple_row) as cur: + assert cur.name == "baz" + assert cur.row_factory is rows.namedtuple_row # type: ignore + + +def test_init_params(conn): + with psycopg.ServerCursor(conn, "foo") as cur: + assert cur.scrollable is None + assert cur.withhold is False + + with psycopg.ServerCursor(conn, "bar", withhold=True, scrollable=False) as cur: + assert cur.scrollable is False + assert cur.withhold is True def test_funny_name(conn): @@ -310,17 +337,17 @@ def test_row_factory(conn): cur = conn.cursor("foo", row_factory=my_row_factory, scrollable=True) cur.execute("select generate_series(1, 3) as x") - rows = cur.fetchall() + recs = cur.fetchall() cur.scroll(0, "absolute") while 1: - row = cur.fetchone() - if not row: + rec = cur.fetchone() + if not rec: break - rows.append(row) - assert rows == [[1, -1], [1, -2], [1, -3]] * 2 + recs.append(rec) + assert recs == [[1, -1], [1, -2], [1, -3]] * 2 cur.scroll(0, "absolute") - cur.row_factory = dict_row + cur.row_factory = rows.dict_row assert cur.fetchone() == {"x": 1} cur.close() diff --git a/tests/test_server_cursor_async.py b/tests/test_server_cursor_async.py index f169cad5a..6b2b9a5e5 100644 --- a/tests/test_server_cursor_async.py +++ b/tests/test_server_cursor_async.py @@ -1,12 +1,43 @@ import pytest -from psycopg import errors as e -from psycopg.rows import dict_row +import psycopg +from psycopg import rows, errors as e from psycopg.pq import Format pytestmark = pytest.mark.asyncio +async def test_init_row_factory(aconn): + async with psycopg.AsyncServerCursor(aconn, "foo") as cur: + assert cur.name == "foo" + assert cur.connection is aconn + assert cur.row_factory is aconn.row_factory + + aconn.row_factory = rows.dict_row + + async with psycopg.AsyncServerCursor(aconn, "bar") as cur: + assert cur.name == "bar" + assert cur.row_factory is rows.dict_row # type: ignore + + async with psycopg.AsyncServerCursor( + aconn, "baz", row_factory=rows.namedtuple_row + ) as cur: + assert cur.name == "baz" + assert cur.row_factory is rows.namedtuple_row # type: ignore + + +async def test_init_params(aconn): + async with psycopg.AsyncServerCursor(aconn, "foo") as cur: + assert cur.scrollable is None + assert cur.withhold is False + + async with psycopg.AsyncServerCursor( + aconn, "bar", withhold=True, scrollable=False + ) as cur: + assert cur.scrollable is False + assert cur.withhold is True + + async def test_funny_name(aconn): cur = aconn.cursor("1-2-3") await cur.execute("select generate_series(1, 3) as bar") @@ -316,17 +347,17 @@ async def test_row_factory(aconn): cur = aconn.cursor("foo", row_factory=my_row_factory, scrollable=True) await cur.execute("select generate_series(1, 3) as x") - rows = await cur.fetchall() + recs = await cur.fetchall() await cur.scroll(0, "absolute") while 1: - row = await cur.fetchone() - if not row: + rec = await cur.fetchone() + if not rec: break - rows.append(row) - assert rows == [[1, -1], [1, -2], [1, -3]] * 2 + recs.append(rec) + assert recs == [[1, -1], [1, -2], [1, -3]] * 2 await cur.scroll(0, "absolute") - cur.row_factory = dict_row + cur.row_factory = rows.dict_row assert await cur.fetchone() == {"x": 1} await cur.close() diff --git a/tests/test_typing.py b/tests/test_typing.py index c8c664bb2..6421db911 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -138,9 +138,9 @@ def test_connection_type(conn, type, mypy): "psycopg.AsyncServerCursor[Dict[str, Any]]", ), ( - "psycopg.connect()", + "await psycopg.AsyncConnection.connect()", "conn.cursor(name='foo', row_factory=rows.dict_row)", - "psycopg.ServerCursor[Dict[str, Any]]", + "psycopg.AsyncServerCursor[Dict[str, Any]]", ), ], ) @@ -152,6 +152,82 @@ obj = {curs} _test_reveal(stmts, type, mypy) +@pytest.mark.parametrize( + "conn, curs, type", + [ + ( + "psycopg.connect()", + "psycopg.Cursor(conn)", + "psycopg.Cursor[Tuple[Any, ...]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "psycopg.Cursor(conn)", + "psycopg.Cursor[Dict[str, Any]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "psycopg.Cursor(conn, row_factory=rows.namedtuple_row)", + "psycopg.Cursor[NamedTuple]", + ), + # Async cursors + ( + "await psycopg.AsyncConnection.connect()", + "psycopg.AsyncCursor(conn)", + "psycopg.AsyncCursor[Tuple[Any, ...]]", + ), + ( + "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)", + "psycopg.AsyncCursor(conn)", + "psycopg.AsyncCursor[Dict[str, Any]]", + ), + ( + "await psycopg.AsyncConnection.connect()", + "psycopg.AsyncCursor(conn, row_factory=thing_row)", + "psycopg.AsyncCursor[Thing]", + ), + # Server-side cursors + ( + "psycopg.connect()", + "psycopg.ServerCursor(conn, 'foo')", + "psycopg.ServerCursor[Tuple[Any, ...]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "psycopg.ServerCursor(conn, name='foo')", + "psycopg.ServerCursor[Dict[str, Any]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "psycopg.ServerCursor(conn, 'foo', row_factory=rows.namedtuple_row)", + "psycopg.ServerCursor[NamedTuple]", + ), + # Async server-side cursors + ( + "await psycopg.AsyncConnection.connect()", + "psycopg.AsyncServerCursor(conn, name='foo')", + "psycopg.AsyncServerCursor[Tuple[Any, ...]]", + ), + ( + "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)", + "psycopg.AsyncServerCursor(conn, name='foo')", + "psycopg.AsyncServerCursor[Dict[str, Any]]", + ), + ( + "await psycopg.AsyncConnection.connect()", + "psycopg.AsyncServerCursor(conn, name='foo', row_factory=rows.dict_row)", + "psycopg.AsyncServerCursor[Dict[str, Any]]", + ), + ], +) +def test_cursor_type_init(conn, curs, type, mypy): + stmts = f"""\ +conn = {conn} +obj = {curs} +""" + _test_reveal(stmts, type, mypy) + + @pytest.mark.parametrize( "curs, type", [