]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fixed rowcount with DML queries, multiple results, executemany
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 29 Oct 2020 00:51:03 +0000 (01:51 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 29 Oct 2020 03:23:26 +0000 (04:23 +0100)
psycopg3/psycopg3/cursor.py
tests/test_cursor.py
tests/test_cursor_async.py

index 31ccc3fa3a17a3ad699303cecce083e7d16d24fa..b21eca483c5a508d64be244eb62565559162d46f 100644 (file)
@@ -88,6 +88,7 @@ class BaseCursor:
         self.pgresult = None
         self._pos = 0
         self._iresult = 0
+        self._rowcount = -1
 
     @property
     def closed(self) -> bool:
@@ -124,11 +125,7 @@ class BaseCursor:
 
     @property
     def rowcount(self) -> int:
-        res = self.pgresult
-        if res is None or res.status != self.ExecStatus.TUPLES_OK:
-            return -1
-        else:
-            return res.ntuples
+        return self._rowcount
 
     def setinputsizes(self, sizes: Sequence[Any]) -> None:
         # no-op
@@ -191,6 +188,13 @@ class BaseCursor:
         if not badstats:
             self._results = list(results)
             self.pgresult = results[0]
+            nrows = self.pgresult.command_tuples
+            if nrows is not None:
+                if self._rowcount < 0:
+                    self._rowcount = nrows
+                else:
+                    self._rowcount += nrows
+
             return
 
         if results[-1].status == S.FATAL_ERROR:
@@ -236,6 +240,8 @@ class BaseCursor:
         if self._iresult < len(self._results):
             self.pgresult = self._results[self._iresult]
             self._pos = 0
+            nrows = self.pgresult.command_tuples
+            self._rowcount = nrows if nrows is not None else -1
             return True
         else:
             return None
index b874af29406c1c83227fd25a224d8f84a2437cfd..9a9038dca8ac856391c3067589a5bda347ff6fcd 100644 (file)
@@ -49,12 +49,12 @@ def test_execute_many_results(conn):
     cur = conn.cursor()
     assert cur.nextset() is None
 
-    rv = cur.execute("select 'foo'; select 'bar'")
+    rv = cur.execute("select 'foo'; select generate_series(1,3)")
     assert rv is cur
-    assert len(cur._results) == 2
-    assert cur.pgresult.get_value(0, 0) == b"foo"
+    assert cur.fetchall() == [("foo",)]
+    assert cur.rowcount == 1
     assert cur.nextset()
-    assert cur.pgresult.get_value(0, 0) == b"bar"
+    assert cur.fetchall() == [(1,), (2,), (3,)]
     assert cur.nextset() is None
 
     cur.close()
@@ -158,7 +158,6 @@ def test_executemany_name(conn, execmany):
     assert cur.fetchall() == [(11, "hello"), (21, "world")]
 
 
-@pytest.mark.xfail
 def test_executemany_rowcount(conn, execmany):
     cur = conn.cursor()
     cur.executemany(
@@ -180,3 +179,20 @@ def test_executemany_badquery(conn, query):
     cur = conn.cursor()
     with pytest.raises(psycopg3.DatabaseError):
         cur.executemany(query, [(10, "hello"), (20, "world")])
+
+
+def test_rowcount(conn):
+    cur = conn.cursor()
+    cur.execute("select 1 from generate_series(1, 42)")
+    assert cur.rowcount == 42
+
+    cur.execute("create table test_rowcount_notuples (id int primary key)")
+    assert cur.rowcount == -1
+
+    cur.execute(
+        "insert into test_rowcount_notuples select generate_series(1, 42)"
+    )
+    assert cur.rowcount == 42
+
+    cur.close()
+    assert cur.rowcount == -1
index 53c3cf2a32e08afc719ece54a730ccc115441eb8..dabada93f4ae020a5f7c7f76d77c62dee75a895a 100644 (file)
@@ -51,12 +51,13 @@ async def test_execute_many_results(aconn):
     cur = aconn.cursor()
     assert cur.nextset() is None
 
-    rv = await cur.execute("select 'foo'; select 'bar'")
+    rv = await cur.execute("select 'foo'; select generate_series(1,3)")
     assert rv is cur
-    assert len(cur._results) == 2
-    assert cur.pgresult.get_value(0, 0) == b"foo"
+    assert (await cur.fetchall()) == [("foo",)]
+    assert cur.rowcount == 1
     assert cur.nextset()
-    assert cur.pgresult.get_value(0, 0) == b"bar"
+    assert (await cur.fetchall()) == [(1,), (2,), (3,)]
+    assert cur.rowcount == 3
     assert cur.nextset() is None
 
     await cur.close()
@@ -157,7 +158,6 @@ async def test_executemany_name(aconn, execmany):
     assert rv == [(11, "hello"), (21, "world")]
 
 
-@pytest.mark.xfail
 async def test_executemany_rowcount(aconn, execmany):
     cur = aconn.cursor()
     await cur.executemany(
@@ -179,3 +179,23 @@ async def test_executemany_badquery(aconn, query):
     cur = aconn.cursor()
     with pytest.raises(psycopg3.DatabaseError):
         await cur.executemany(query, [(10, "hello"), (20, "world")])
+
+
+async def test_rowcount(aconn):
+    cur = aconn.cursor()
+
+    await cur.execute("select 1 from generate_series(1, 42)")
+    assert cur.rowcount == 42
+
+    await cur.execute(
+        "create table test_rowcount_notuples (id int primary key)"
+    )
+    assert cur.rowcount == -1
+
+    await cur.execute(
+        "insert into test_rowcount_notuples select generate_series(1, 42)"
+    )
+    assert cur.rowcount == 42
+
+    await cur.close()
+    assert cur.rowcount == -1