]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Default server-side cursor scrollable to not defined
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 11 Feb 2021 18:35:38 +0000 (19:35 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 11 Feb 2021 18:35:38 +0000 (19:35 +0100)
psycopg3/psycopg3/server_cursor.py
tests/test_server_cursor.py
tests/test_server_cursor_async.py

index 2ded325cf6fb2a45d0a7536c368ccd849d7810e3..88f7bad0e180a30d21818115df91ea6132eace34 100644 (file)
@@ -138,22 +138,28 @@ class ServerCursorHelper(Generic[ConnectionType]):
         self,
         cur: BaseCursor[ConnectionType],
         query: Query,
-        scrollable: bool,
+        scrollable: Optional[bool],
         hold: bool,
     ) -> sql.Composable:
+
         if isinstance(query, bytes):
             query = query.decode(cur._conn.client_encoding)
         if not isinstance(query, sql.Composable):
             query = sql.SQL(query)
 
-        return sql.SQL(
-            "declare {name} {scroll} cursor{hold} for {query}"
-        ).format(
-            name=sql.Identifier(self.name),
-            scroll=sql.SQL("scroll" if scrollable else "no scroll"),
-            hold=sql.SQL(" with hold" if hold else ""),
-            query=query,
-        )
+        parts = [
+            sql.SQL("declare"),
+            sql.Identifier(self.name),
+        ]
+        if scrollable is not None:
+            parts.append(sql.SQL("scroll" if scrollable else "no scroll"))
+        parts.append(sql.SQL("cursor"))
+        if hold:
+            parts.append(sql.SQL("with hold"))
+        parts.append(sql.SQL("for"))
+        parts.append(query)
+
+        return sql.SQL(" ").join(parts)
 
 
 class ServerCursor(BaseCursor["Connection"]):
@@ -213,7 +219,7 @@ class ServerCursor(BaseCursor["Connection"]):
         query: Query,
         params: Optional[Params] = None,
         *,
-        scrollable: bool = True,
+        scrollable: Optional[bool] = None,
         hold: bool = False,
     ) -> "ServerCursor":
         """
@@ -329,7 +335,7 @@ class AsyncServerCursor(BaseCursor["AsyncConnection"]):
         query: Query,
         params: Optional[Params] = None,
         *,
-        scrollable: bool = True,
+        scrollable: Optional[bool] = None,
         hold: bool = False,
     ) -> "AsyncServerCursor":
         query = self._helper._make_declare_statement(
index 78d456507ae804bb673a33c507e73049cc747cc3..09ae7bac4d1c3b213a00e15c2616ac08a22dcb04 100644 (file)
@@ -206,7 +206,7 @@ def test_scroll(conn):
     with pytest.raises(e.ProgrammingError):
         cur.scroll(0)
 
-    cur.execute("select generate_series(0,9)")
+    cur.execute("select generate_series(0,9)", scrollable=True)
     cur.scroll(2)
     assert cur.fetchone() == (2,)
     cur.scroll(2)
@@ -222,7 +222,7 @@ def test_scroll(conn):
 
 def test_scrollable(conn):
     curs = conn.cursor("foo")
-    curs.execute("select generate_series(0, 5)")
+    curs.execute("select generate_series(0, 5)", scrollable=True)
     curs.scroll(5)
     for i in range(4, -1, -1):
         curs.scroll(-1)
index 4a34655351b2b72b5600495dd85d4b84f30723ac..8326578d6805688af748cae606be0de833d3fae5 100644 (file)
@@ -213,7 +213,7 @@ async def test_scroll(aconn):
     with pytest.raises(e.ProgrammingError):
         await cur.scroll(0)
 
-    await cur.execute("select generate_series(0,9)")
+    await cur.execute("select generate_series(0,9)", scrollable=True)
     await cur.scroll(2)
     assert await cur.fetchone() == (2,)
     await cur.scroll(2)
@@ -229,7 +229,7 @@ async def test_scroll(aconn):
 
 async def test_scrollable(aconn):
     curs = aconn.cursor("foo")
-    await curs.execute("select generate_series(0, 5)")
+    await curs.execute("select generate_series(0, 5)", scrollable=True)
     await curs.scroll(5)
     for i in range(4, -1, -1):
         await curs.scroll(-1)