]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: make Cursor constructor public
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 4 May 2022 18:09:31 +0000 (20:09 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 12 May 2022 13:36:06 +0000 (15:36 +0200)
This allows a more natural creation of cursors subclasses, without the
need of tweaking the connection's cursor_factory and allowing to pass
arbitrary init arguments to the cursor.

psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
psycopg/psycopg/server_cursor.py
tests/test_cursor.py
tests/test_cursor_async.py
tests/test_server_cursor.py
tests/test_server_cursor_async.py
tests/test_typing.py

index 7a169207d95ff95590508b620bf4d2684dad48d7..ff4d1ea7e5c930139e9aa087b97adff1ed4e0213 100644 (file)
@@ -643,11 +643,15 @@ class Connection(BaseConnection[Row]):
 
     _pipeline: Optional[Pipeline]
 
-    def __init__(self, pgconn: "PGconn", row_factory: Optional[RowFactory[Row]] = None):
+    def __init__(
+        self,
+        pgconn: "PGconn",
+        row_factory: RowFactory[Row] = cast(RowFactory[Row], tuple_row),
+    ):
         super().__init__(pgconn)
-        self.row_factory = row_factory or cast(RowFactory[Row], tuple_row)
+        self.row_factory = row_factory
         self.lock = threading.Lock()
-        self.cursor_factory = Cursor
+        self.cursor_factory = cast("Type[Cursor[Row]]", Cursor)
         self.server_cursor_factory = ServerCursor
 
     @overload
index ef6c2bbd75a5ea4b79720e6ebaeb5a564182dfde..2cbaa745d80a3fac4d44ef3f9dfb19f4ecf06b31 100644 (file)
@@ -51,12 +51,12 @@ class AsyncConnection(BaseConnection[Row]):
     def __init__(
         self,
         pgconn: "PGconn",
-        row_factory: Optional[AsyncRowFactory[Row]] = None,
+        row_factory: AsyncRowFactory[Row] = cast(AsyncRowFactory[Row], tuple_row),
     ):
         super().__init__(pgconn)
-        self.row_factory = row_factory or cast(AsyncRowFactory[Row], tuple_row)
+        self.row_factory = row_factory
         self.lock = asyncio.Lock()
-        self.cursor_factory = AsyncCursor
+        self.cursor_factory = cast("Type[AsyncCursor[Row]]", AsyncCursor)
         self.server_cursor_factory = AsyncServerCursor
 
     @overload
index 4593515819cb99bed6142536e5e7a1372fe2879d..22a10e9e4b51aec30827b9120e2eb8e857fac175 100644 (file)
@@ -7,7 +7,8 @@ psycopg cursor objects
 from functools import partial
 from types import TracebackType
 from typing import Any, Generic, Iterable, Iterator, List
-from typing import Optional, NoReturn, Sequence, Type, TypeVar, TYPE_CHECKING
+from typing import Optional, NoReturn, Sequence, Type, TypeVar
+from typing import overload, TYPE_CHECKING
 from contextlib import contextmanager
 
 from . import pq
@@ -637,9 +638,27 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
     __module__ = "psycopg"
     __slots__ = ()
 
-    def __init__(self, connection: "Connection[Any]", *, row_factory: RowFactory[Row]):
+    @overload
+    def __init__(self: "Cursor[Row]", connection: "Connection[Row]"):
+        ...
+
+    @overload
+    def __init__(
+        self: "Cursor[Row]",
+        connection: "Connection[Any]",
+        *,
+        row_factory: RowFactory[Row],
+    ):
+        ...
+
+    def __init__(
+        self,
+        connection: "Connection[Any]",
+        *,
+        row_factory: Optional[RowFactory[Row]] = None,
+    ):
         super().__init__(connection)
-        self._row_factory = row_factory
+        self._row_factory = row_factory or connection.row_factory
 
     def __enter__(self: _C) -> _C:
         return self
index 8c4ac4d9e66b269727dc9d427e9b5a57f8420dcd..d6c7dbbba0e0f863a9df117868599d6c500b89b9 100644 (file)
@@ -6,7 +6,7 @@ psycopg async cursor objects
 
 from types import TracebackType
 from typing import Any, AsyncIterator, Iterable, List
-from typing import Optional, Type, TypeVar, TYPE_CHECKING
+from typing import Optional, Type, TypeVar, TYPE_CHECKING, overload
 from contextlib import asynccontextmanager
 
 from . import errors as e
@@ -26,14 +26,27 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
     __module__ = "psycopg"
     __slots__ = ()
 
+    @overload
+    def __init__(self: "AsyncCursor[Row]", connection: "AsyncConnection[Row]"):
+        ...
+
+    @overload
     def __init__(
-        self,
+        self: "AsyncCursor[Row]",
         connection: "AsyncConnection[Any]",
         *,
         row_factory: AsyncRowFactory[Row],
+    ):
+        ...
+
+    def __init__(
+        self,
+        connection: "AsyncConnection[Any]",
+        *,
+        row_factory: Optional[AsyncRowFactory[Row]] = None,
     ):
         super().__init__(connection)
-        self._row_factory = row_factory
+        self._row_factory = row_factory or connection.row_factory
 
     async def __aenter__(self: _C) -> _C:
         return self
index 9561297f1351562cbe60b952048da8ffa251b77b..d7913090a34dd0aff52f9ae00f24f7478b32e05b 100644 (file)
@@ -5,7 +5,7 @@ psycopg server-side cursor objects.
 # Copyright (C) 2020 The Psycopg Team
 
 from typing import Any, AsyncIterator, Generic, List, Iterable, Iterator
-from typing import Optional, TypeVar, TYPE_CHECKING
+from typing import Optional, TypeVar, TYPE_CHECKING, overload
 from warnings import warn
 
 from . import pq
@@ -182,8 +182,20 @@ class ServerCursor(Cursor[Row]):
     __module__ = "psycopg"
     __slots__ = ("_helper", "itersize")
 
+    @overload
     def __init__(
-        self,
+        self: "ServerCursor[Row]",
+        connection: "Connection[Row]",
+        name: str,
+        *,
+        scrollable: Optional[bool] = None,
+        withhold: bool = False,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self: "ServerCursor[Row]",
         connection: "Connection[Any]",
         name: str,
         *,
@@ -191,7 +203,18 @@ class ServerCursor(Cursor[Row]):
         scrollable: Optional[bool] = None,
         withhold: bool = False,
     ):
-        super().__init__(connection, row_factory=row_factory)
+        ...
+
+    def __init__(
+        self,
+        connection: "Connection[Any]",
+        name: str,
+        *,
+        row_factory: Optional[RowFactory[Row]] = None,
+        scrollable: Optional[bool] = None,
+        withhold: bool = False,
+    ):
+        super().__init__(connection, row_factory=row_factory or connection.row_factory)
         self._helper: ServerCursorHelper["Connection[Any]", Row]
         self._helper = ServerCursorHelper(name, scrollable, withhold)
         self.itersize: int = DEFAULT_ITERSIZE
@@ -323,8 +346,20 @@ class AsyncServerCursor(AsyncCursor[Row]):
     __module__ = "psycopg"
     __slots__ = ("_helper", "itersize")
 
+    @overload
     def __init__(
-        self,
+        self: "AsyncServerCursor[Row]",
+        connection: "AsyncConnection[Row]",
+        name: str,
+        *,
+        scrollable: Optional[bool] = None,
+        withhold: bool = False,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self: "AsyncServerCursor[Row]",
         connection: "AsyncConnection[Any]",
         name: str,
         *,
@@ -332,7 +367,18 @@ class AsyncServerCursor(AsyncCursor[Row]):
         scrollable: Optional[bool] = None,
         withhold: bool = False,
     ):
-        super().__init__(connection, row_factory=row_factory)
+        ...
+
+    def __init__(
+        self,
+        connection: "AsyncConnection[Any]",
+        name: str,
+        *,
+        row_factory: Optional[AsyncRowFactory[Row]] = None,
+        scrollable: Optional[bool] = None,
+        withhold: bool = False,
+    ):
+        super().__init__(connection, row_factory=row_factory or connection.row_factory)
         self._helper: ServerCursorHelper["AsyncConnection[Any]", Row]
         self._helper = ServerCursorHelper(name, scrollable, withhold)
         self.itersize: int = DEFAULT_ITERSIZE
index 1d1b6275d038ecc97622a3a056c75491532f7660..a52b3fc374b412481039cfdb63937bbb4402e25b 100644 (file)
@@ -15,6 +15,23 @@ from psycopg.rows import RowMaker
 from .utils import gc_collect
 
 
+def test_init(conn):
+    cur = psycopg.Cursor(conn)
+    cur.execute("select 1")
+    assert cur.fetchone() == (1,)
+
+    conn.row_factory = rows.dict_row
+    cur = psycopg.Cursor(conn)
+    cur.execute("select 1 as a")
+    assert cur.fetchone() == {"a": 1}
+
+
+def test_init_factory(conn):
+    cur = psycopg.Cursor(conn, row_factory=rows.dict_row)
+    cur.execute("select 1 as a")
+    assert cur.fetchone() == {"a": 1}
+
+
 def test_close(conn):
     cur = conn.cursor()
     assert not cur.closed
index 78aeafd5a3faf3a7cb0f5d556b8fb90a17bcb299..ea5bf7b139df02342f4e38cebfe346adbfe8f6d4 100644 (file)
@@ -16,6 +16,23 @@ execmany = execmany  # avoid F811 underneath
 pytestmark = pytest.mark.asyncio
 
 
+async def test_init(aconn):
+    cur = psycopg.AsyncCursor(aconn)
+    await cur.execute("select 1")
+    assert (await cur.fetchone()) == (1,)
+
+    aconn.row_factory = rows.dict_row
+    cur = psycopg.AsyncCursor(aconn)
+    await cur.execute("select 1 as a")
+    assert (await cur.fetchone()) == {"a": 1}
+
+
+async def test_init_factory(aconn):
+    cur = psycopg.AsyncCursor(aconn, row_factory=rows.dict_row)
+    await cur.execute("select 1 as a")
+    assert (await cur.fetchone()) == {"a": 1}
+
+
 async def test_close(aconn):
     cur = aconn.cursor()
     assert not cur.closed
index 9b77fa8527d46b773031b5119ca63c22b2d4b067..78b51ce00bfc2bfa10c3f74eb1fe04672bbfa98b 100644 (file)
@@ -1,8 +1,35 @@
 import pytest
 
-from psycopg import errors as e
+import psycopg
+from psycopg import rows, errors as e
 from psycopg.pq import Format
-from psycopg.rows import dict_row
+
+
+def test_init_row_factory(conn):
+    with psycopg.ServerCursor(conn, "foo") as cur:
+        assert cur.name == "foo"
+        assert cur.connection is conn
+        assert cur.row_factory is conn.row_factory
+
+    conn.row_factory = rows.dict_row
+
+    with psycopg.ServerCursor(conn, "bar") as cur:
+        assert cur.name == "bar"
+        assert cur.row_factory is rows.dict_row  # type: ignore
+
+    with psycopg.ServerCursor(conn, "baz", row_factory=rows.namedtuple_row) as cur:
+        assert cur.name == "baz"
+        assert cur.row_factory is rows.namedtuple_row  # type: ignore
+
+
+def test_init_params(conn):
+    with psycopg.ServerCursor(conn, "foo") as cur:
+        assert cur.scrollable is None
+        assert cur.withhold is False
+
+    with psycopg.ServerCursor(conn, "bar", withhold=True, scrollable=False) as cur:
+        assert cur.scrollable is False
+        assert cur.withhold is True
 
 
 def test_funny_name(conn):
@@ -310,17 +337,17 @@ def test_row_factory(conn):
 
     cur = conn.cursor("foo", row_factory=my_row_factory, scrollable=True)
     cur.execute("select generate_series(1, 3) as x")
-    rows = cur.fetchall()
+    recs = cur.fetchall()
     cur.scroll(0, "absolute")
     while 1:
-        row = cur.fetchone()
-        if not row:
+        rec = cur.fetchone()
+        if not rec:
             break
-        rows.append(row)
-    assert rows == [[1, -1], [1, -2], [1, -3]] * 2
+        recs.append(rec)
+    assert recs == [[1, -1], [1, -2], [1, -3]] * 2
 
     cur.scroll(0, "absolute")
-    cur.row_factory = dict_row
+    cur.row_factory = rows.dict_row
     assert cur.fetchone() == {"x": 1}
     cur.close()
 
index f169cad5ae771b1c783f22ef6e230f0a705146e8..6b2b9a5e52bfb71b811a055bff8a3b6dd4075ace 100644 (file)
@@ -1,12 +1,43 @@
 import pytest
 
-from psycopg import errors as e
-from psycopg.rows import dict_row
+import psycopg
+from psycopg import rows, errors as e
 from psycopg.pq import Format
 
 pytestmark = pytest.mark.asyncio
 
 
+async def test_init_row_factory(aconn):
+    async with psycopg.AsyncServerCursor(aconn, "foo") as cur:
+        assert cur.name == "foo"
+        assert cur.connection is aconn
+        assert cur.row_factory is aconn.row_factory
+
+    aconn.row_factory = rows.dict_row
+
+    async with psycopg.AsyncServerCursor(aconn, "bar") as cur:
+        assert cur.name == "bar"
+        assert cur.row_factory is rows.dict_row  # type: ignore
+
+    async with psycopg.AsyncServerCursor(
+        aconn, "baz", row_factory=rows.namedtuple_row
+    ) as cur:
+        assert cur.name == "baz"
+        assert cur.row_factory is rows.namedtuple_row  # type: ignore
+
+
+async def test_init_params(aconn):
+    async with psycopg.AsyncServerCursor(aconn, "foo") as cur:
+        assert cur.scrollable is None
+        assert cur.withhold is False
+
+    async with psycopg.AsyncServerCursor(
+        aconn, "bar", withhold=True, scrollable=False
+    ) as cur:
+        assert cur.scrollable is False
+        assert cur.withhold is True
+
+
 async def test_funny_name(aconn):
     cur = aconn.cursor("1-2-3")
     await cur.execute("select generate_series(1, 3) as bar")
@@ -316,17 +347,17 @@ async def test_row_factory(aconn):
 
     cur = aconn.cursor("foo", row_factory=my_row_factory, scrollable=True)
     await cur.execute("select generate_series(1, 3) as x")
-    rows = await cur.fetchall()
+    recs = await cur.fetchall()
     await cur.scroll(0, "absolute")
     while 1:
-        row = await cur.fetchone()
-        if not row:
+        rec = await cur.fetchone()
+        if not rec:
             break
-        rows.append(row)
-    assert rows == [[1, -1], [1, -2], [1, -3]] * 2
+        recs.append(rec)
+    assert recs == [[1, -1], [1, -2], [1, -3]] * 2
 
     await cur.scroll(0, "absolute")
-    cur.row_factory = dict_row
+    cur.row_factory = rows.dict_row
     assert await cur.fetchone() == {"x": 1}
     await cur.close()
 
index c8c664bb28cff5fb22009a981dccda74c0cf6d5d..6421db91183963b33a9bc0eb7d3e9f97b41515d5 100644 (file)
@@ -138,9 +138,9 @@ def test_connection_type(conn, type, mypy):
             "psycopg.AsyncServerCursor[Dict[str, Any]]",
         ),
         (
-            "psycopg.connect()",
+            "await psycopg.AsyncConnection.connect()",
             "conn.cursor(name='foo', row_factory=rows.dict_row)",
-            "psycopg.ServerCursor[Dict[str, Any]]",
+            "psycopg.AsyncServerCursor[Dict[str, Any]]",
         ),
     ],
 )
@@ -152,6 +152,82 @@ obj = {curs}
     _test_reveal(stmts, type, mypy)
 
 
+@pytest.mark.parametrize(
+    "conn, curs, type",
+    [
+        (
+            "psycopg.connect()",
+            "psycopg.Cursor(conn)",
+            "psycopg.Cursor[Tuple[Any, ...]]",
+        ),
+        (
+            "psycopg.connect(row_factory=rows.dict_row)",
+            "psycopg.Cursor(conn)",
+            "psycopg.Cursor[Dict[str, Any]]",
+        ),
+        (
+            "psycopg.connect(row_factory=rows.dict_row)",
+            "psycopg.Cursor(conn, row_factory=rows.namedtuple_row)",
+            "psycopg.Cursor[NamedTuple]",
+        ),
+        # Async cursors
+        (
+            "await psycopg.AsyncConnection.connect()",
+            "psycopg.AsyncCursor(conn)",
+            "psycopg.AsyncCursor[Tuple[Any, ...]]",
+        ),
+        (
+            "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)",
+            "psycopg.AsyncCursor(conn)",
+            "psycopg.AsyncCursor[Dict[str, Any]]",
+        ),
+        (
+            "await psycopg.AsyncConnection.connect()",
+            "psycopg.AsyncCursor(conn, row_factory=thing_row)",
+            "psycopg.AsyncCursor[Thing]",
+        ),
+        # Server-side cursors
+        (
+            "psycopg.connect()",
+            "psycopg.ServerCursor(conn, 'foo')",
+            "psycopg.ServerCursor[Tuple[Any, ...]]",
+        ),
+        (
+            "psycopg.connect(row_factory=rows.dict_row)",
+            "psycopg.ServerCursor(conn, name='foo')",
+            "psycopg.ServerCursor[Dict[str, Any]]",
+        ),
+        (
+            "psycopg.connect(row_factory=rows.dict_row)",
+            "psycopg.ServerCursor(conn, 'foo', row_factory=rows.namedtuple_row)",
+            "psycopg.ServerCursor[NamedTuple]",
+        ),
+        # Async server-side cursors
+        (
+            "await psycopg.AsyncConnection.connect()",
+            "psycopg.AsyncServerCursor(conn, name='foo')",
+            "psycopg.AsyncServerCursor[Tuple[Any, ...]]",
+        ),
+        (
+            "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)",
+            "psycopg.AsyncServerCursor(conn, name='foo')",
+            "psycopg.AsyncServerCursor[Dict[str, Any]]",
+        ),
+        (
+            "await psycopg.AsyncConnection.connect()",
+            "psycopg.AsyncServerCursor(conn, name='foo', row_factory=rows.dict_row)",
+            "psycopg.AsyncServerCursor[Dict[str, Any]]",
+        ),
+    ],
+)
+def test_cursor_type_init(conn, curs, type, mypy):
+    stmts = f"""\
+conn = {conn}
+obj = {curs}
+"""
+    _test_reveal(stmts, type, mypy)
+
+
 @pytest.mark.parametrize(
     "curs, type",
     [