]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Move withhold/scrollable as server-side cursor attributes
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 22 Jul 2021 15:41:50 +0000 (17:41 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 23 Jul 2021 14:38:54 +0000 (16:38 +0200)
Add respective properties to read back the state.

docs/api/connections.rst
docs/api/cursors.rst
psycopg/psycopg/connection.py
psycopg/psycopg/server_cursor.py
tests/test_server_cursor.py
tests/test_server_cursor_async.py

index c257a7b849dbbde037eac7cdbc27c68151ed3be0..cdc0965232d6ccb94dee6c09ce8a23cc466c4d2f 100644 (file)
@@ -58,14 +58,24 @@ The `!Connection` class
 
     .. automethod:: close
 
-        .. note:: You can use :ref:`with connect(): ...<with-connection>` to
-            close the connection automatically when the block is exited.
+        .. note::
+
+            You can use::
+
+                with psycopg.connect() as conn:
+                    ...
+
+            to close the connection automatically when the block is exited.
+            See :ref:`with-connection`.
 
     .. autoattribute:: closed
     .. autoattribute:: broken
 
-    .. method:: cursor(*, binary: bool = False, row_factory: Optional[RowFactory] = None) -> Cursor
-    .. method:: cursor(name: str, *, binary: bool = False, row_factory: Optional[RowFactory] = None) -> ServerCursor
+    .. method:: cursor(*, binary: bool = False, \
+           row_factory: Optional[RowFactory] = None) -> Cursor
+    .. method:: cursor(name: str, *, binary: bool = False, \
+            row_factory: Optional[RowFactory] = None, \
+            scrollable: Optional[bool] = None, withhold: bool = False) -> ServerCursor
         :noindex:
 
         Return a new cursor to send commands and queries to the connection.
@@ -236,8 +246,11 @@ The `!AsyncConnection` class
             automatically when the block is exited, but be careful about
             the async quirkness: see :ref:`async-with` for details.
 
-    .. method:: cursor(*, binary: bool = False, row_factory: Optional[RowFactory] = None) -> AsyncCursor
-    .. method:: cursor(name: str, *, binary: bool = False, row_factory: Optional[RowFactory] = None) -> AsyncServerCursor
+    .. method:: cursor(*, binary: bool = False, \
+            row_factory: Optional[RowFactory] = None) -> AsyncCursor
+    .. method:: cursor(name: str, *, binary: bool = False, \
+            row_factory: Optional[RowFactory] = None, \
+            scrollable: Optional[bool] = None, withhold: bool = False) -> AsyncServerCursor
         :noindex:
 
         .. note:: You can use ``async with conn.cursor() as cur: ...`` to
index 8221bf2f894f79da51c2ba1abae5fcb4c9c80f5e..dd54b5c09ccb46fbd0640bc56f093e6668889043 100644 (file)
@@ -186,6 +186,18 @@ The `!ServerCursor` class
     documented the differences:
 
     .. autoattribute:: name
+    .. autoattribute:: scrollable
+
+       .. seealso:: The PostgreSQL DECLARE_ statement documetation
+          for the description of :sql:`[NO] SCROLL`.
+
+    .. autoattribute:: withhold
+
+       .. seealso:: The PostgreSQL DECLARE_ statement documetation
+          for the description of :sql:`{WITH|WITHOUT} HOLD`.
+
+    .. _DECLARE: https://www.postgresql.org/docs/current/sql-declare.html
+
 
     .. automethod:: close
 
@@ -196,18 +208,12 @@ The `!ServerCursor` class
             ...` pattern is especially useful so that the cursor is closed at
             the end of the block.
 
-    .. automethod:: execute(query, params=None, *, scrollable=None, withhold=False) -> ServerCursor
+    .. automethod:: execute(query, params=None, *) -> ServerCursor
 
         :param query: The query to execute.
         :type query: `!str`, `!bytes`, or `sql.Composable`
         :param params: The parameters to pass to the query, if any.
         :type params: Sequence or Mapping
-        :param scrollable: if `!True` make the cursor scrollable, if `!False`
-                           not. if `!None` leave the choice to the server.
-        :type scrollable: `!Optional[bool]`
-        :param withhold: if `!True` allow the cursor to be used after the
-                         transaction creating it has committed.
-        :type withhold: `!bool`
 
         Create a server cursor with given `name` and the *query* in argument.
         If using :sql:`DECLARE` is not appropriate you can avoid to use
@@ -217,11 +223,6 @@ The `!ServerCursor` class
         Using `!execute()` more than once will close the previous cursor and
         open a new one with the same name.
 
-        .. seealso:: The PostgreSQL DECLARE_ statement documetation describe
-            in details all the parameters.
-
-        .. _DECLARE: https://www.postgresql.org/docs/current/sql-declare.html
-
     .. automethod:: executemany(query: Query, params_seq: Sequence[Args])
 
     .. automethod:: fetchone
@@ -248,7 +249,7 @@ The `!ServerCursor` class
         This method uses the MOVE_ SQL statement to move the current position
         in the server-side cursor, which will affect following `!fetch*()`
         operations. If you need to scroll backwards you should probably
-        use `scrollable=True` in `execute()`.
+        call `~Connection.cursor()` using `scrollable=True`.
 
         Note that PostgreSQL doesn't provide a reliable way to report when a
         cursor moves out of bound, so the method might not raise `!IndexError`
@@ -314,7 +315,7 @@ The `!AsyncServerCursor` class
         .. note:: You can close the cursor automatically using :samp:`async
             with conn.cursor({name}): ...`
 
-    .. automethod:: execute(query, params=None, *, scrollable=None, withhold=False) -> AsyncServerCursor
+    .. automethod:: execute(query, params=None) -> AsyncServerCursor
     .. automethod:: executemany(query: Query, params_seq: Sequence[Args])
     .. automethod:: fetchone
     .. automethod:: fetchmany
index 6eb9d8991f4dba821c61c7d0d0c7fd09ceaf8804..2a8481eba4fc9b35b89d3745152f5747f3e8e3e1 100644 (file)
@@ -543,7 +543,14 @@ class Connection(BaseConnection[Row]):
         ...
 
     @overload
-    def cursor(self, name: str, *, binary: bool = False) -> ServerCursor[Row]:
+    def cursor(
+        self,
+        name: str,
+        *,
+        binary: bool = False,
+        scrollable: Optional[bool] = None,
+        withhold: bool = False,
+    ) -> ServerCursor[Row]:
         ...
 
     @overload
@@ -553,6 +560,8 @@ class Connection(BaseConnection[Row]):
         *,
         binary: bool = False,
         row_factory: RowFactory[CursorRow],
+        scrollable: Optional[bool] = None,
+        withhold: bool = False,
     ) -> ServerCursor[CursorRow]:
         ...
 
@@ -562,6 +571,8 @@ class Connection(BaseConnection[Row]):
         *,
         binary: bool = False,
         row_factory: Optional[RowFactory[Any]] = None,
+        scrollable: Optional[bool] = None,
+        withhold: bool = False,
     ) -> Union[Cursor[Any], ServerCursor[Any]]:
         """
         Return a new cursor to send commands and queries to the connection.
@@ -572,7 +583,11 @@ class Connection(BaseConnection[Row]):
         cur: Union[Cursor[Any], ServerCursor[Any]]
         if name:
             cur = self.server_cursor_factory(
-                self, name=name, row_factory=row_factory
+                self,
+                name=name,
+                row_factory=row_factory,
+                scrollable=scrollable,
+                withhold=withhold,
             )
         else:
             cur = self.cursor_factory(self, row_factory=row_factory)
@@ -764,7 +779,12 @@ class AsyncConnection(BaseConnection[Row]):
 
     @overload
     def cursor(
-        self, name: str, *, binary: bool = False
+        self,
+        name: str,
+        *,
+        binary: bool = False,
+        scrollable: Optional[bool] = None,
+        withhold: bool = False,
     ) -> AsyncServerCursor[Row]:
         ...
 
@@ -775,6 +795,8 @@ class AsyncConnection(BaseConnection[Row]):
         *,
         binary: bool = False,
         row_factory: RowFactory[CursorRow],
+        scrollable: Optional[bool] = None,
+        withhold: bool = False,
     ) -> AsyncServerCursor[CursorRow]:
         ...
 
@@ -784,6 +806,8 @@ class AsyncConnection(BaseConnection[Row]):
         *,
         binary: bool = False,
         row_factory: Optional[RowFactory[Any]] = None,
+        scrollable: Optional[bool] = None,
+        withhold: bool = False,
     ) -> Union[AsyncCursor[Any], AsyncServerCursor[Any]]:
         """
         Return a new `AsyncCursor` to send commands and queries to the connection.
@@ -794,7 +818,11 @@ class AsyncConnection(BaseConnection[Row]):
         cur: Union[AsyncCursor[Any], AsyncServerCursor[Any]]
         if name:
             cur = self.server_cursor_factory(
-                self, name=name, row_factory=row_factory
+                self,
+                name=name,
+                row_factory=row_factory,
+                scrollable=scrollable,
+                withhold=withhold,
             )
         else:
             cur = self.cursor_factory(self, row_factory=row_factory)
index 6638ec1fc85dfe3b95536b5a9e75f2f15b1bc522..2bf703eaad0cf428428f6592040e27506457dc20 100644 (file)
@@ -25,15 +25,22 @@ DEFAULT_ITERSIZE = 100
 
 
 class ServerCursorHelper(Generic[ConnectionType, Row]):
-    __slots__ = ("name", "described")
+    __slots__ = ("name", "scrollable", "withhold", "described")
     """Helper object for common ServerCursor code.
 
     TODO: this should be a mixin, but couldn't find a way to work it
     correctly with the generic.
     """
 
-    def __init__(self, name: str):
+    def __init__(
+        self,
+        name: str,
+        scrollable: Optional[bool],
+        withhold: bool,
+    ):
         self.name = name
+        self.scrollable = scrollable
+        self.withhold = withhold
         self.described = False
 
     def _repr(self, cur: BaseCursor[ConnectionType, Row]) -> str:
@@ -143,8 +150,6 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
         self,
         cur: BaseCursor[ConnectionType, Row],
         query: Query,
-        scrollable: Optional[bool],
-        withhold: bool,
     ) -> sql.Composable:
 
         if isinstance(query, bytes):
@@ -156,10 +161,10 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
             sql.SQL("declare"),
             sql.Identifier(self.name),
         ]
-        if scrollable is not None:
-            parts.append(sql.SQL("scroll" if scrollable else "no scroll"))
+        if self.scrollable is not None:
+            parts.append(sql.SQL("scroll" if self.scrollable else "no scroll"))
         parts.append(sql.SQL("cursor"))
-        if withhold:
+        if self.withhold:
             parts.append(sql.SQL("with hold"))
         parts.append(sql.SQL("for"))
         parts.append(query)
@@ -177,10 +182,12 @@ class ServerCursor(BaseCursor["Connection[Any]", Row]):
         name: str,
         *,
         row_factory: RowFactory[Row],
+        scrollable: Optional[bool] = None,
+        withhold: bool = False,
     ):
         super().__init__(connection, row_factory=row_factory)
         self._helper: ServerCursorHelper["Connection[Any]", Row]
-        self._helper = ServerCursorHelper(name)
+        self._helper = ServerCursorHelper(name, scrollable, withhold)
         self.itersize: int = DEFAULT_ITERSIZE
 
     def __del__(self) -> None:
@@ -210,6 +217,23 @@ class ServerCursor(BaseCursor["Connection[Any]", Row]):
         """The name of the cursor."""
         return self._helper.name
 
+    @property
+    def scrollable(self) -> Optional[bool]:
+        """
+        Whether the cursor is scrollable or not.
+
+        If `!None` leave the choice to the server. Use `!True` if you want to
+        use `scroll()` on the cursor.
+        """
+        return self._helper.scrollable
+
+    @property
+    def withhold(self) -> bool:
+        """
+        If the cursor can be used after the creating transaction has committed.
+        """
+        return self._helper.withhold
+
     def close(self) -> None:
         """
         Close the current cursor and free associated resources.
@@ -222,16 +246,11 @@ class ServerCursor(BaseCursor["Connection[Any]", Row]):
         self,
         query: Query,
         params: Optional[Params] = None,
-        *,
-        scrollable: Optional[bool] = None,
-        withhold: bool = False,
     ) -> "ServerCursor[Row]":
         """
         Open a cursor to execute a query to the database.
         """
-        query = self._helper._make_declare_statement(
-            self, query, scrollable=scrollable, withhold=withhold
-        )
+        query = self._helper._make_declare_statement(self, query)
         with self._conn.lock:
             self._conn.wait(self._helper._declare_gen(self, query, params))
         return self
@@ -297,10 +316,12 @@ class AsyncServerCursor(BaseCursor["AsyncConnection[Any]", Row]):
         name: str,
         *,
         row_factory: RowFactory[Row],
+        scrollable: Optional[bool] = None,
+        withhold: bool = False,
     ):
         super().__init__(connection, row_factory=row_factory)
         self._helper: ServerCursorHelper["AsyncConnection[Any]", Row]
-        self._helper = ServerCursorHelper(name)
+        self._helper = ServerCursorHelper(name, scrollable, withhold)
         self.itersize: int = DEFAULT_ITERSIZE
 
     def __del__(self) -> None:
@@ -329,6 +350,14 @@ class AsyncServerCursor(BaseCursor["AsyncConnection[Any]", Row]):
     def name(self) -> str:
         return self._helper.name
 
+    @property
+    def scrollable(self) -> Optional[bool]:
+        return self._helper.scrollable
+
+    @property
+    def withhold(self) -> bool:
+        return self._helper.withhold
+
     async def close(self) -> None:
         async with self._conn.lock:
             await self._conn.wait(self._helper._close_gen(self))
@@ -338,13 +367,8 @@ class AsyncServerCursor(BaseCursor["AsyncConnection[Any]", Row]):
         self,
         query: Query,
         params: Optional[Params] = None,
-        *,
-        scrollable: Optional[bool] = None,
-        withhold: bool = False,
     ) -> "AsyncServerCursor[Row]":
-        query = self._helper._make_declare_statement(
-            self, query, scrollable=scrollable, withhold=withhold
-        )
+        query = self._helper._make_declare_statement(self, query)
         async with self._conn.lock:
             await self._conn.wait(
                 self._helper._declare_gen(self, query, params)
index 50fa3b4249e39f1243b814f9984f70ef99467702..6b2974802abd099cd899b74cfc1f37cb3c505200 100644 (file)
@@ -180,8 +180,8 @@ def test_row_factory(conn):
         n += 1
         return lambda values: [n] + [-v for v in values]
 
-    cur = conn.cursor("foo", row_factory=my_row_factory)
-    cur.execute("select generate_series(1, 3) as x", scrollable=True)
+    cur = conn.cursor("foo", row_factory=my_row_factory, scrollable=True)
+    cur.execute("select generate_series(1, 3) as x")
     rows = cur.fetchall()
     cur.scroll(0, "absolute")
     while 1:
@@ -247,12 +247,16 @@ def test_itersize(conn, commands):
             assert ("fetch forward 2") in cmd.lower()
 
 
-def test_scroll(conn):
+def test_cant_scroll_by_default(conn):
     cur = conn.cursor("tmp")
+    assert cur.scrollable is None
     with pytest.raises(e.ProgrammingError):
         cur.scroll(0)
 
-    cur.execute("select generate_series(0,9)", scrollable=True)
+
+def test_scroll(conn):
+    cur = conn.cursor("tmp", scrollable=True)
+    cur.execute("select generate_series(0,9)")
     cur.scroll(2)
     assert cur.fetchone() == (2,)
     cur.scroll(2)
@@ -267,8 +271,9 @@ def test_scroll(conn):
 
 
 def test_scrollable(conn):
-    curs = conn.cursor("foo")
-    curs.execute("select generate_series(0, 5)", scrollable=True)
+    curs = conn.cursor("foo", scrollable=True)
+    assert curs.scrollable is True
+    curs.execute("select generate_series(0, 5)")
     curs.scroll(5)
     for i in range(4, -1, -1):
         curs.scroll(-1)
@@ -277,8 +282,9 @@ def test_scrollable(conn):
 
 
 def test_non_scrollable(conn):
-    curs = conn.cursor("foo")
-    curs.execute("select generate_series(0, 5)", scrollable=False)
+    curs = conn.cursor("foo", scrollable=False)
+    assert curs.scrollable is False
+    curs.execute("select generate_series(0, 5)")
     curs.scroll(5)
     with pytest.raises(e.OperationalError):
         curs.scroll(-1)
@@ -287,16 +293,18 @@ def test_non_scrollable(conn):
 @pytest.mark.parametrize("kwargs", [{}, {"withhold": False}])
 def test_no_hold(conn, kwargs):
     with pytest.raises(e.InvalidCursorName):
-        with conn.cursor("foo") as curs:
-            curs.execute("select generate_series(0, 2)", **kwargs)
+        with conn.cursor("foo", **kwargs) as curs:
+            assert curs.withhold is False
+            curs.execute("select generate_series(0, 2)")
             assert curs.fetchone() == (0,)
             conn.commit()
             curs.fetchone()
 
 
 def test_hold(conn):
-    with conn.cursor("foo") as curs:
-        curs.execute("select generate_series(0, 5)", withhold=True)
+    with conn.cursor("foo", withhold=True) as curs:
+        assert curs.withhold is True
+        curs.execute("select generate_series(0, 5)")
         assert curs.fetchone() == (0,)
         conn.commit()
         assert curs.fetchone() == (1,)
index b0e4127bccf648506537c8166931c6d702601914..ca64590b90db126246ee490aebbd06e0fc307d17 100644 (file)
@@ -186,8 +186,8 @@ async def test_row_factory(aconn):
         n += 1
         return lambda values: [n] + [-v for v in values]
 
-    cur = aconn.cursor("foo", row_factory=my_row_factory)
-    await cur.execute("select generate_series(1, 3) as x", scrollable=True)
+    cur = aconn.cursor("foo", row_factory=my_row_factory, scrollable=True)
+    await cur.execute("select generate_series(1, 3) as x")
     rows = await cur.fetchall()
     await cur.scroll(0, "absolute")
     while 1:
@@ -258,12 +258,16 @@ async def test_itersize(aconn, acommands):
             assert ("fetch forward 2") in cmd.lower()
 
 
-async def test_scroll(aconn):
+async def test_cant_scroll_by_default(aconn):
     cur = aconn.cursor("tmp")
+    assert cur.scrollable is None
     with pytest.raises(e.ProgrammingError):
         await cur.scroll(0)
 
-    await cur.execute("select generate_series(0,9)", scrollable=True)
+
+async def test_scroll(aconn):
+    cur = aconn.cursor("tmp", scrollable=True)
+    await cur.execute("select generate_series(0,9)")
     await cur.scroll(2)
     assert await cur.fetchone() == (2,)
     await cur.scroll(2)
@@ -278,8 +282,9 @@ async def test_scroll(aconn):
 
 
 async def test_scrollable(aconn):
-    curs = aconn.cursor("foo")
-    await curs.execute("select generate_series(0, 5)", scrollable=True)
+    curs = aconn.cursor("foo", scrollable=True)
+    assert curs.scrollable is True
+    await curs.execute("select generate_series(0, 5)")
     await curs.scroll(5)
     for i in range(4, -1, -1):
         await curs.scroll(-1)
@@ -288,8 +293,9 @@ async def test_scrollable(aconn):
 
 
 async def test_non_scrollable(aconn):
-    curs = aconn.cursor("foo")
-    await curs.execute("select generate_series(0, 5)", scrollable=False)
+    curs = aconn.cursor("foo", scrollable=False)
+    assert curs.scrollable is False
+    await curs.execute("select generate_series(0, 5)")
     await curs.scroll(5)
     with pytest.raises(e.OperationalError):
         await curs.scroll(-1)
@@ -298,16 +304,18 @@ async def test_non_scrollable(aconn):
 @pytest.mark.parametrize("kwargs", [{}, {"withhold": False}])
 async def test_no_hold(aconn, kwargs):
     with pytest.raises(e.InvalidCursorName):
-        async with aconn.cursor("foo") as curs:
-            await curs.execute("select generate_series(0, 2)", **kwargs)
+        async with aconn.cursor("foo", **kwargs) as curs:
+            assert curs.withhold is False
+            await curs.execute("select generate_series(0, 2)")
             assert await curs.fetchone() == (0,)
             await aconn.commit()
             await curs.fetchone()
 
 
 async def test_hold(aconn):
-    async with aconn.cursor("foo") as curs:
-        await curs.execute("select generate_series(0, 5)", withhold=True)
+    async with aconn.cursor("foo", withhold=True) as curs:
+        assert curs.withhold is True
+        await curs.execute("select generate_series(0, 5)")
         assert await curs.fetchone() == (0,)
         await aconn.commit()
         assert await curs.fetchone() == (1,)