From: Daniele Varrazzo Date: Thu, 11 Feb 2021 18:35:38 +0000 (+0100) Subject: Default server-side cursor scrollable to not defined X-Git-Tag: 3.0.dev0~115^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=10a7523e002e853cd7ef39e0a516baff18bdd776;p=thirdparty%2Fpsycopg.git Default server-side cursor scrollable to not defined --- diff --git a/psycopg3/psycopg3/server_cursor.py b/psycopg3/psycopg3/server_cursor.py index 2ded325cf..88f7bad0e 100644 --- a/psycopg3/psycopg3/server_cursor.py +++ b/psycopg3/psycopg3/server_cursor.py @@ -138,22 +138,28 @@ class ServerCursorHelper(Generic[ConnectionType]): self, cur: BaseCursor[ConnectionType], query: Query, - scrollable: bool, + scrollable: Optional[bool], hold: bool, ) -> sql.Composable: + if isinstance(query, bytes): query = query.decode(cur._conn.client_encoding) if not isinstance(query, sql.Composable): query = sql.SQL(query) - return sql.SQL( - "declare {name} {scroll} cursor{hold} for {query}" - ).format( - name=sql.Identifier(self.name), - scroll=sql.SQL("scroll" if scrollable else "no scroll"), - hold=sql.SQL(" with hold" if hold else ""), - query=query, - ) + parts = [ + sql.SQL("declare"), + sql.Identifier(self.name), + ] + if scrollable is not None: + parts.append(sql.SQL("scroll" if scrollable else "no scroll")) + parts.append(sql.SQL("cursor")) + if hold: + parts.append(sql.SQL("with hold")) + parts.append(sql.SQL("for")) + parts.append(query) + + return sql.SQL(" ").join(parts) class ServerCursor(BaseCursor["Connection"]): @@ -213,7 +219,7 @@ class ServerCursor(BaseCursor["Connection"]): query: Query, params: Optional[Params] = None, *, - scrollable: bool = True, + scrollable: Optional[bool] = None, hold: bool = False, ) -> "ServerCursor": """ @@ -329,7 +335,7 @@ class AsyncServerCursor(BaseCursor["AsyncConnection"]): query: Query, params: Optional[Params] = None, *, - scrollable: bool = True, + scrollable: Optional[bool] = None, hold: bool = False, ) -> "AsyncServerCursor": query = self._helper._make_declare_statement( diff --git a/tests/test_server_cursor.py b/tests/test_server_cursor.py index 78d456507..09ae7bac4 100644 --- a/tests/test_server_cursor.py +++ b/tests/test_server_cursor.py @@ -206,7 +206,7 @@ def test_scroll(conn): with pytest.raises(e.ProgrammingError): cur.scroll(0) - cur.execute("select generate_series(0,9)") + cur.execute("select generate_series(0,9)", scrollable=True) cur.scroll(2) assert cur.fetchone() == (2,) cur.scroll(2) @@ -222,7 +222,7 @@ def test_scroll(conn): def test_scrollable(conn): curs = conn.cursor("foo") - curs.execute("select generate_series(0, 5)") + curs.execute("select generate_series(0, 5)", scrollable=True) curs.scroll(5) for i in range(4, -1, -1): curs.scroll(-1) diff --git a/tests/test_server_cursor_async.py b/tests/test_server_cursor_async.py index 4a3465535..8326578d6 100644 --- a/tests/test_server_cursor_async.py +++ b/tests/test_server_cursor_async.py @@ -213,7 +213,7 @@ async def test_scroll(aconn): with pytest.raises(e.ProgrammingError): await cur.scroll(0) - await cur.execute("select generate_series(0,9)") + await cur.execute("select generate_series(0,9)", scrollable=True) await cur.scroll(2) assert await cur.fetchone() == (2,) await cur.scroll(2) @@ -229,7 +229,7 @@ async def test_scroll(aconn): async def test_scrollable(aconn): curs = aconn.cursor("foo") - await curs.execute("select generate_series(0, 5)") + await curs.execute("select generate_series(0, 5)", scrollable=True) await curs.scroll(5) for i in range(4, -1, -1): await curs.scroll(-1)