From 46f3050913a71dee62d8d9ea25afa564913d6bf7 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Thu, 29 Oct 2020 01:51:03 +0100 Subject: [PATCH] Fixed rowcount with DML queries, multiple results, executemany --- psycopg3/psycopg3/cursor.py | 16 +++++++++++----- tests/test_cursor.py | 26 +++++++++++++++++++++----- tests/test_cursor_async.py | 30 +++++++++++++++++++++++++----- 3 files changed, 57 insertions(+), 15 deletions(-) diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index 31ccc3fa3..b21eca483 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -88,6 +88,7 @@ class BaseCursor: self.pgresult = None self._pos = 0 self._iresult = 0 + self._rowcount = -1 @property def closed(self) -> bool: @@ -124,11 +125,7 @@ class BaseCursor: @property def rowcount(self) -> int: - res = self.pgresult - if res is None or res.status != self.ExecStatus.TUPLES_OK: - return -1 - else: - return res.ntuples + return self._rowcount def setinputsizes(self, sizes: Sequence[Any]) -> None: # no-op @@ -191,6 +188,13 @@ class BaseCursor: if not badstats: self._results = list(results) self.pgresult = results[0] + nrows = self.pgresult.command_tuples + if nrows is not None: + if self._rowcount < 0: + self._rowcount = nrows + else: + self._rowcount += nrows + return if results[-1].status == S.FATAL_ERROR: @@ -236,6 +240,8 @@ class BaseCursor: if self._iresult < len(self._results): self.pgresult = self._results[self._iresult] self._pos = 0 + nrows = self.pgresult.command_tuples + self._rowcount = nrows if nrows is not None else -1 return True else: return None diff --git a/tests/test_cursor.py b/tests/test_cursor.py index b874af294..9a9038dca 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -49,12 +49,12 @@ def test_execute_many_results(conn): cur = conn.cursor() assert cur.nextset() is None - rv = cur.execute("select 'foo'; select 'bar'") + rv = cur.execute("select 'foo'; select generate_series(1,3)") assert rv is cur - assert len(cur._results) == 2 - assert cur.pgresult.get_value(0, 0) == b"foo" + assert cur.fetchall() == [("foo",)] + assert cur.rowcount == 1 assert cur.nextset() - assert cur.pgresult.get_value(0, 0) == b"bar" + assert cur.fetchall() == [(1,), (2,), (3,)] assert cur.nextset() is None cur.close() @@ -158,7 +158,6 @@ def test_executemany_name(conn, execmany): assert cur.fetchall() == [(11, "hello"), (21, "world")] -@pytest.mark.xfail def test_executemany_rowcount(conn, execmany): cur = conn.cursor() cur.executemany( @@ -180,3 +179,20 @@ def test_executemany_badquery(conn, query): cur = conn.cursor() with pytest.raises(psycopg3.DatabaseError): cur.executemany(query, [(10, "hello"), (20, "world")]) + + +def test_rowcount(conn): + cur = conn.cursor() + cur.execute("select 1 from generate_series(1, 42)") + assert cur.rowcount == 42 + + cur.execute("create table test_rowcount_notuples (id int primary key)") + assert cur.rowcount == -1 + + cur.execute( + "insert into test_rowcount_notuples select generate_series(1, 42)" + ) + assert cur.rowcount == 42 + + cur.close() + assert cur.rowcount == -1 diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py index 53c3cf2a3..dabada93f 100644 --- a/tests/test_cursor_async.py +++ b/tests/test_cursor_async.py @@ -51,12 +51,13 @@ async def test_execute_many_results(aconn): cur = aconn.cursor() assert cur.nextset() is None - rv = await cur.execute("select 'foo'; select 'bar'") + rv = await cur.execute("select 'foo'; select generate_series(1,3)") assert rv is cur - assert len(cur._results) == 2 - assert cur.pgresult.get_value(0, 0) == b"foo" + assert (await cur.fetchall()) == [("foo",)] + assert cur.rowcount == 1 assert cur.nextset() - assert cur.pgresult.get_value(0, 0) == b"bar" + assert (await cur.fetchall()) == [(1,), (2,), (3,)] + assert cur.rowcount == 3 assert cur.nextset() is None await cur.close() @@ -157,7 +158,6 @@ async def test_executemany_name(aconn, execmany): assert rv == [(11, "hello"), (21, "world")] -@pytest.mark.xfail async def test_executemany_rowcount(aconn, execmany): cur = aconn.cursor() await cur.executemany( @@ -179,3 +179,23 @@ async def test_executemany_badquery(aconn, query): cur = aconn.cursor() with pytest.raises(psycopg3.DatabaseError): await cur.executemany(query, [(10, "hello"), (20, "world")]) + + +async def test_rowcount(aconn): + cur = aconn.cursor() + + await cur.execute("select 1 from generate_series(1, 42)") + assert cur.rowcount == 42 + + await cur.execute( + "create table test_rowcount_notuples (id int primary key)" + ) + assert cur.rowcount == -1 + + await cur.execute( + "insert into test_rowcount_notuples select generate_series(1, 42)" + ) + assert cur.rowcount == 42 + + await cur.close() + assert cur.rowcount == -1 -- 2.47.2