]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add test for all the documented NamedCursor interface
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 10 Feb 2021 14:30:27 +0000 (15:30 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 10 Feb 2021 14:30:27 +0000 (15:30 +0100)
psycopg3/psycopg3/named_cursor.py
tests/test_named_cursor.py
tests/test_named_cursor_async.py

index 6e6ec71eeb9676a398133a3ccd45d433dc9fa010..abb3d1cd6a12aba275287395701c739bb3193f4e 100644 (file)
@@ -10,6 +10,7 @@ from typing import Any, AsyncIterator, Generic, List, Iterator, Optional
 from typing import Sequence, Type, Tuple, TYPE_CHECKING
 
 from . import sql
+from . import errors as e
 from .pq import Format
 from .cursor import BaseCursor, execute
 from .proto import ConnectionType, Query, Params, PQGen
@@ -195,6 +196,9 @@ class NamedCursor(BaseCursor["Connection"]):
             self._conn.wait(self._helper._declare_gen(self, query, params))
         return self
 
+    def executemany(self, query: Query, params_seq: Sequence[Params]) -> None:
+        raise e.NotSupportedError("executemany not supported on named cursors")
+
     def fetchone(self) -> Optional[Sequence[Any]]:
         with self._conn.lock:
             recs = self._conn.wait(self._helper._fetch_gen(self, 1))
@@ -307,6 +311,11 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]):
             )
         return self
 
+    async def executemany(
+        self, query: Query, params_seq: Sequence[Params]
+    ) -> None:
+        raise e.NotSupportedError("executemany not supported on named cursors")
+
     async def fetchone(self) -> Optional[Sequence[Any]]:
         async with self._conn.lock:
             recs = await self._conn.wait(self._helper._fetch_gen(self, 1))
index e3bbca10a390d0fd8f83d38431a8dd8efac5f5cc..cbbf23904c1827ab43ac156dc048f46e98286472 100644 (file)
@@ -1,5 +1,7 @@
 import pytest
 
+from psycopg3.pq import Format
+
 
 def test_funny_name(conn):
     cur = conn.cursor("1-2-3")
@@ -8,6 +10,11 @@ def test_funny_name(conn):
     assert cur.name == "1-2-3"
 
 
+def test_connection(conn):
+    cur = conn.cursor("foo")
+    assert cur.connection is conn
+
+
 def test_description(conn):
     cur = conn.cursor("foo")
     assert cur.name == "foo"
@@ -18,6 +25,24 @@ def test_description(conn):
     assert cur.pgresult.ntuples == 0
 
 
+def test_format(conn):
+    cur = conn.cursor("foo")
+    assert cur.format == Format.TEXT
+
+    cur = conn.cursor("foo", binary=True)
+    assert cur.format == Format.BINARY
+
+
+def test_query_params(conn):
+    with conn.cursor("foo") as cur:
+        assert cur.query is None
+        assert cur.params is None
+        cur.execute("select generate_series(1, %s) as bar", (3,))
+        assert b"declare" in cur.query.lower()
+        assert b"(1, $1)" in cur.query.lower()
+        assert cur.params == [bytes([0, 3])]  # 3 as binary int2
+
+
 def test_close(conn, recwarn):
     cur = conn.cursor("foo")
     cur.execute("select generate_series(1, 10) as bar")
@@ -56,6 +81,12 @@ def test_warn_close(conn, recwarn):
     assert ".close()" in str(recwarn.pop(ResourceWarning).message)
 
 
+def test_executemany(conn):
+    cur = conn.cursor("foo")
+    with pytest.raises(conn.NotSupportedError):
+        cur.executemany("select %s", [(1,), (2,)])
+
+
 def test_fetchone(conn):
     with conn.cursor("foo") as cur:
         cur.execute("select generate_series(1, %s) as bar", (2,))
@@ -86,6 +117,12 @@ def test_fetchall(conn):
         assert cur.fetchall() == []
 
 
+def test_nextset(conn):
+    with conn.cursor("foo") as cur:
+        cur.execute("select generate_series(1, %s) as bar", (3,))
+        assert not cur.nextset()
+
+
 def test_rownumber(conn):
     cur = conn.cursor("foo")
     assert cur.rownumber is None
index 7d6fa49b1a12885dadb71db051a315c5a12e9c84..79bb35338c87a03d3bf40f3fd0e89e9e633c9e18 100644 (file)
@@ -1,5 +1,7 @@
 import pytest
 
+from psycopg3.pq import Format
+
 pytestmark = pytest.mark.asyncio
 
 
@@ -10,6 +12,11 @@ async def test_funny_name(aconn):
     assert cur.name == "1-2-3"
 
 
+async def test_connection(aconn):
+    cur = aconn.cursor("foo")
+    assert cur.connection is aconn
+
+
 async def test_description(aconn):
     cur = aconn.cursor("foo")
     assert cur.name == "foo"
@@ -20,6 +27,24 @@ async def test_description(aconn):
     assert cur.pgresult.ntuples == 0
 
 
+async def test_format(aconn):
+    cur = aconn.cursor("foo")
+    assert cur.format == Format.TEXT
+
+    cur = aconn.cursor("foo", binary=True)
+    assert cur.format == Format.BINARY
+
+
+async def test_query_params(aconn):
+    async with aconn.cursor("foo") as cur:
+        assert cur.query is None
+        assert cur.params is None
+        await cur.execute("select generate_series(1, %s) as bar", (3,))
+        assert b"declare" in cur.query.lower()
+        assert b"(1, $1)" in cur.query.lower()
+        assert cur.params == [bytes([0, 3])]  # 3 as binary int2
+
+
 async def test_close(aconn, recwarn):
     cur = aconn.cursor("foo")
     await cur.execute("select generate_series(1, 10) as bar")
@@ -58,6 +83,12 @@ async def test_warn_close(aconn, recwarn):
     assert ".close()" in str(recwarn.pop(ResourceWarning).message)
 
 
+async def test_executemany(aconn):
+    cur = aconn.cursor("foo")
+    with pytest.raises(aconn.NotSupportedError):
+        await cur.executemany("select %s", [(1,), (2,)])
+
+
 async def test_fetchone(aconn):
     async with aconn.cursor("foo") as cur:
         await cur.execute("select generate_series(1, %s) as bar", (2,))
@@ -88,6 +119,12 @@ async def test_fetchall(aconn):
         assert await cur.fetchall() == []
 
 
+async def test_nextset(aconn):
+    async with aconn.cursor("foo") as cur:
+        await cur.execute("select generate_series(1, %s) as bar", (3,))
+        assert not cur.nextset()
+
+
 async def test_rownumber(aconn):
     cur = aconn.cursor("foo")
     assert cur.rownumber is None
@@ -192,3 +229,14 @@ async def test_steal_cursor(aconn):
     assert await cur2.fetchone() == (1,)
     assert await cur2.fetchmany(3) == [(2,), (3,), (4,)]
     assert await cur2.fetchall() == [(5,), (6,)]
+
+
+async def test_stolen_cursor_close(aconn):
+    cur1 = aconn.cursor()
+    await cur1.execute("declare test cursor for select generate_series(1, 6)")
+    cur2 = aconn.cursor("test")
+    await cur2.close()
+
+    await cur1.execute("declare test cursor for select generate_series(1, 6)")
+    cur2 = aconn.cursor("test")
+    await cur2.close()