From: Daniele Varrazzo Date: Tue, 9 Feb 2021 20:04:00 +0000 (+0100) Subject: Add fetch methods to named cursors X-Git-Tag: 3.0.dev0~115^2~17 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d7d73e529906c6c6e5ccaac137f6b27ba57c8670;p=thirdparty%2Fpsycopg.git Add fetch methods to named cursors --- diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index 7d097a601..c2a5a8436 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -321,12 +321,12 @@ class BaseConnection(AdaptContext): conn._autocommit = autocommit return conn - def _exec_command(self, command: Query) -> PQGen[None]: + def _exec_command(self, command: Query) -> PQGen["PGresult"]: """ Generator to send a command and receive the result to the backend. - Only used to implement internal commands such as commit, returning - no result. The cursor can do more complex stuff. + Only used to implement internal commands such as "commit", with eventual + arguments bound client-side. The cursor can do more complex stuff. """ if self.pgconn.status != ConnStatus.OK: if self.pgconn.status == ConnStatus.BAD: @@ -343,7 +343,7 @@ class BaseConnection(AdaptContext): self.pgconn.send_query(command) result = (yield from execute(self.pgconn))[-1] - if result.status != ExecStatus.COMMAND_OK: + if result.status not in (ExecStatus.COMMAND_OK, ExecStatus.TUPLES_OK): if result.status == ExecStatus.FATAL_ERROR: raise e.error_from_result( result, encoding=self.client_encoding @@ -353,6 +353,7 @@ class BaseConnection(AdaptContext): f"unexpected result {ExecStatus(result.status).name}" f" from command {command.decode('utf8')!r}" ) + return result def _start_query(self) -> PQGen[None]: """Generator to start a transaction if necessary.""" diff --git a/psycopg3/psycopg3/named_cursor.py b/psycopg3/psycopg3/named_cursor.py index 54e10e890..109b2bb23 100644 --- a/psycopg3/psycopg3/named_cursor.py +++ b/psycopg3/psycopg3/named_cursor.py @@ -7,7 +7,8 @@ psycopg3 named cursor objects (server-side cursors) import weakref import warnings from types import TracebackType -from typing import Any, Generic, Optional, Type, TYPE_CHECKING +from typing import Any, AsyncIterator, Generic, List, Iterator, Optional +from typing import Sequence, Type, Tuple, TYPE_CHECKING from . import sql from .pq import Format @@ -18,6 +19,8 @@ if TYPE_CHECKING: from .connection import BaseConnection # noqa: F401 from .connection import Connection, AsyncConnection # noqa: F401 +DEFAULT_ITERSIZE = 100 + class NamedCursorHelper(Generic[ConnectionType]): __slots__ = ("name", "_wcur") @@ -46,17 +49,18 @@ class NamedCursorHelper(Generic[ConnectionType]): ) -> PQGen[None]: """Generator implementing `NamedCursor.execute()`.""" cur = self._cur + conn = cur._conn yield from cur._start_query(query) pgq = cur._convert_query(query, params) cur._execute_send(pgq) - results = yield from execute(cur._conn.pgconn) + results = yield from execute(conn.pgconn) cur._execute_results(results) # The above result is an COMMAND_OK. Get the cursor result shape - cur._conn.pgconn.send_describe_portal( - self.name.encode(cur._conn.client_encoding) + conn.pgconn.send_describe_portal( + self.name.encode(conn.client_encoding) ) - results = yield from execute(cur._conn.pgconn) + results = yield from execute(conn.pgconn) cur._execute_results(results) def _close_gen(self) -> PQGen[None]: @@ -64,6 +68,24 @@ class NamedCursorHelper(Generic[ConnectionType]): query = sql.SQL("close {}").format(sql.Identifier(self.name)) yield from cur._conn._exec_command(query) + def _fetch_gen(self, num: Optional[int]) -> PQGen[List[Tuple[Any, ...]]]: + if num is not None: + howmuch: sql.Composable = sql.Literal(num) + else: + howmuch = sql.SQL("all") + + cur = self._cur + query = sql.SQL("fetch forward {} from {}").format( + howmuch, sql.Identifier(self.name) + ) + res = yield from cur._conn._exec_command(query) + + # TODO: loaders don't need to be refreshed + cur.pgresult = res + nrows = res.ntuples + cur._pos += nrows + return cur._tx.load_rows(0, nrows) + def _make_declare_statement( self, query: Query, scrollable: bool, hold: bool ) -> sql.Composable: @@ -85,7 +107,7 @@ class NamedCursorHelper(Generic[ConnectionType]): class NamedCursor(BaseCursor["Connection"]): __module__ = "psycopg3" - __slots__ = ("_helper",) + __slots__ = ("_helper", "itersize") def __init__( self, @@ -96,6 +118,7 @@ class NamedCursor(BaseCursor["Connection"]): ): super().__init__(connection, format=format) self._helper = NamedCursorHelper(name, self) + self.itersize = DEFAULT_ITERSIZE def __del__(self) -> None: if not self._closed: @@ -146,10 +169,36 @@ class NamedCursor(BaseCursor["Connection"]): self._conn.wait(self._helper._declare_gen(query, params)) return self + def fetchone(self) -> Optional[Sequence[Any]]: + with self._conn.lock: + recs = self._conn.wait(self._helper._fetch_gen(1)) + return recs[0] if recs else None + + def fetchmany(self, size: int = 0) -> Sequence[Sequence[Any]]: + if not size: + size = self.arraysize + with self._conn.lock: + recs = self._conn.wait(self._helper._fetch_gen(size)) + return recs + + def fetchall(self) -> Sequence[Sequence[Any]]: + with self._conn.lock: + recs = self._conn.wait(self._helper._fetch_gen(None)) + return recs + + def __iter__(self) -> Iterator[Sequence[Any]]: + while True: + with self._conn.lock: + recs = self._conn.wait(self._helper._fetch_gen(self.itersize)) + for rec in recs: + yield rec + if len(recs) < self.itersize: + break + class AsyncNamedCursor(BaseCursor["AsyncConnection"]): __module__ = "psycopg3" - __slots__ = ("_helper",) + __slots__ = ("_helper", "itersize") def __init__( self, @@ -160,6 +209,7 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]): ): super().__init__(connection, format=format) self._helper = NamedCursorHelper(name, self) + self.itersize = DEFAULT_ITERSIZE def __del__(self) -> None: if not self._closed: @@ -209,3 +259,31 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]): async with self._conn.lock: await self._conn.wait(self._helper._declare_gen(query, params)) return self + + async def fetchone(self) -> Optional[Sequence[Any]]: + async with self._conn.lock: + recs = await self._conn.wait(self._helper._fetch_gen(1)) + return recs[0] if recs else None + + async def fetchmany(self, size: int = 0) -> Sequence[Sequence[Any]]: + if not size: + size = self.arraysize + async with self._conn.lock: + recs = await self._conn.wait(self._helper._fetch_gen(size)) + return recs + + async def fetchall(self) -> Sequence[Sequence[Any]]: + async with self._conn.lock: + recs = await self._conn.wait(self._helper._fetch_gen(None)) + return recs + + async def __aiter__(self) -> AsyncIterator[Sequence[Any]]: + while True: + async with self._conn.lock: + recs = await self._conn.wait( + self._helper._fetch_gen(self.itersize) + ) + for rec in recs: + yield rec + if len(recs) < self.itersize: + break diff --git a/psycopg3/psycopg3/transaction.py b/psycopg3/psycopg3/transaction.py index a07ce6f73..650a4f1ef 100644 --- a/psycopg3/psycopg3/transaction.py +++ b/psycopg3/psycopg3/transaction.py @@ -13,6 +13,7 @@ from . import pq from . import sql from .pq import TransactionStatus from .proto import ConnectionType, PQGen +from .pq.proto import PGresult if TYPE_CHECKING: from .connection import Connection, AsyncConnection # noqa: F401 @@ -80,7 +81,7 @@ class BaseTransaction(Generic[ConnectionType]): sp = f"{self.savepoint_name!r} " if self.savepoint_name else "" return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>" - def _enter_gen(self) -> PQGen[None]: + def _enter_gen(self) -> PQGen[PGresult]: if self._entered: raise TypeError("transaction blocks can be used only once") self._entered = True @@ -126,7 +127,7 @@ class BaseTransaction(Generic[ConnectionType]): else: return (yield from self._rollback_gen(exc_val)) - def _commit_gen(self) -> PQGen[None]: + def _commit_gen(self) -> PQGen[PGresult]: assert self._conn._savepoints[-1] == self._savepoint_name self._conn._savepoints.pop() self._exited = True diff --git a/tests/fix_db.py b/tests/fix_db.py index 039bc1be0..1cab75592 100644 --- a/tests/fix_db.py +++ b/tests/fix_db.py @@ -65,3 +65,45 @@ def svcconn(dsn): conn = Connection.connect(dsn, autocommit=True) yield conn conn.close() + + +@pytest.fixture +def commands(conn, monkeypatch): + """The list of commands issued internally by the test connection.""" + yield patch_exec(conn, monkeypatch) + + +@pytest.fixture +def acommands(aconn, monkeypatch): + """The list of commands issued internally by the test async connection.""" + yield patch_exec(aconn, monkeypatch) + + +def patch_exec(conn, monkeypatch): + """Helper to implement the commands fixture both sync and async.""" + from psycopg3 import sql + + _orig_exec_command = conn._exec_command + L = ListPopAll() + + def _exec_command(command): + cmdcopy = command + if isinstance(cmdcopy, bytes): + cmdcopy = cmdcopy.decode(conn.client_encoding) + elif isinstance(cmdcopy, sql.Composable): + cmdcopy = cmdcopy.as_string(conn) + + L.insert(0, cmdcopy) + return _orig_exec_command(command) + + monkeypatch.setattr(conn, "_exec_command", _exec_command) + return L + + +class ListPopAll(list): + """A list, with a popall() method.""" + + def popall(self): + out = self[:] + del self[:] + return out diff --git a/tests/test_named_cursor.py b/tests/test_named_cursor.py index b9421a0cb..5dd89cbae 100644 --- a/tests/test_named_cursor.py +++ b/tests/test_named_cursor.py @@ -38,3 +38,77 @@ def test_warn_close(conn, recwarn): cur.execute("select generate_series(1, 10) as bar") del cur assert ".close()" in str(recwarn.pop(ResourceWarning).message) + + +def test_fetchone(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (2,)) + assert cur.fetchone() == (1,) + assert cur.fetchone() == (2,) + assert cur.fetchone() is None + + +def test_fetchmany(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (5,)) + assert cur.fetchmany(3) == [(1,), (2,), (3,)] + assert cur.fetchone() == (4,) + assert cur.fetchmany(3) == [(5,)] + assert cur.fetchmany(3) == [] + + +def test_fetchall(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (3,)) + assert cur.fetchall() == [(1,), (2,), (3,)] + assert cur.fetchall() == [] + + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (3,)) + assert cur.fetchone() == (1,) + assert cur.fetchall() == [(2,), (3,)] + assert cur.fetchall() == [] + + +def test_rownumber(conn): + cur = conn.cursor("foo") + assert cur.rownumber is None + + cur.execute("select 1 from generate_series(1, 42)") + assert cur.rownumber == 0 + + cur.fetchone() + assert cur.rownumber == 1 + cur.fetchone() + assert cur.rownumber == 2 + cur.fetchmany(10) + assert cur.rownumber == 12 + cur.fetchall() + assert cur.rownumber == 42 + + +def test_iter(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (3,)) + recs = list(cur) + assert recs == [(1,), (2,), (3,)] + + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (3,)) + assert cur.fetchone() == (1,) + recs = list(cur) + assert recs == [(2,), (3,)] + + +def test_itersize(conn, commands): + with conn.cursor("foo") as cur: + assert cur.itersize == 100 + cur.itersize = 2 + cur.execute("select generate_series(1, %s) as bar", (3,)) + commands.popall() # flush begin and other noise + + list(cur) + cmds = commands.popall() + assert len(cmds) == 2 + for cmd in cmds: + assert ("fetch forward 2") in cmd.lower() diff --git a/tests/test_named_cursor_async.py b/tests/test_named_cursor_async.py index 0ceeee114..d56951a73 100644 --- a/tests/test_named_cursor_async.py +++ b/tests/test_named_cursor_async.py @@ -43,3 +43,82 @@ async def test_warn_close(aconn, recwarn): await cur.execute("select generate_series(1, 10) as bar") del cur assert ".close()" in str(recwarn.pop(ResourceWarning).message) + + +async def test_fetchone(aconn): + async with await aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (2,)) + assert await cur.fetchone() == (1,) + assert await cur.fetchone() == (2,) + assert await cur.fetchone() is None + + +async def test_fetchmany(aconn): + async with await aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (5,)) + assert await cur.fetchmany(3) == [(1,), (2,), (3,)] + assert await cur.fetchone() == (4,) + assert await cur.fetchmany(3) == [(5,)] + assert await cur.fetchmany(3) == [] + + +async def test_fetchall(aconn): + async with await aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + assert await cur.fetchall() == [(1,), (2,), (3,)] + assert await cur.fetchall() == [] + + async with await aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + assert await cur.fetchone() == (1,) + assert await cur.fetchall() == [(2,), (3,)] + assert await cur.fetchall() == [] + + +async def test_rownumber(aconn): + cur = await aconn.cursor("foo") + assert cur.rownumber is None + + await cur.execute("select 1 from generate_series(1, 42)") + assert cur.rownumber == 0 + + await cur.fetchone() + assert cur.rownumber == 1 + await cur.fetchone() + assert cur.rownumber == 2 + await cur.fetchmany(10) + assert cur.rownumber == 12 + await cur.fetchall() + assert cur.rownumber == 42 + + +async def test_iter(aconn): + async with await aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + recs = [] + async for rec in cur: + recs.append(rec) + assert recs == [(1,), (2,), (3,)] + + async with await aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + assert await cur.fetchone() == (1,) + recs = [] + async for rec in cur: + recs.append(rec) + assert recs == [(2,), (3,)] + + +async def test_itersize(aconn, acommands): + async with await aconn.cursor("foo") as cur: + assert cur.itersize == 100 + cur.itersize = 2 + await cur.execute("select generate_series(1, %s) as bar", (3,)) + acommands.popall() # flush begin and other noise + + async for rec in cur: + pass + cmds = acommands.popall() + assert len(cmds) == 2 + for cmd in cmds: + assert ("fetch forward 2") in cmd.lower() diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 5f48edb92..48795531e 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -43,37 +43,6 @@ def inserted(conn): return f() -class ListPopAll(list): - """A list, with a popall() method.""" - - def popall(self): - out = self[:] - del self[:] - return out - - -@pytest.fixture -def commands(conn, monkeypatch): - """The list of commands issued internally by the test connection.""" - yield patch_exec(conn, monkeypatch) - - -def patch_exec(conn, monkeypatch): - """Helper to implement the commands fixture both sync and async.""" - _orig_exec_command = conn._exec_command - L = ListPopAll() - - def _exec_command(command): - if isinstance(command, bytes): - command = command.decode(conn.client_encoding) - - L.insert(0, command) - return _orig_exec_command(command) - - monkeypatch.setattr(conn, "_exec_command", _exec_command) - return L - - def in_transaction(conn): if conn.pgconn.transaction_status == conn.TransactionStatus.IDLE: return False diff --git a/tests/test_transaction_async.py b/tests/test_transaction_async.py index 2bfab9301..c955039c9 100644 --- a/tests/test_transaction_async.py +++ b/tests/test_transaction_async.py @@ -3,18 +3,12 @@ import pytest from psycopg3 import ProgrammingError, Rollback from .test_transaction import in_transaction, insert_row, inserted -from .test_transaction import ExpectedException, patch_exec +from .test_transaction import ExpectedException from .test_transaction import create_test_table # noqa # autouse fixture pytestmark = pytest.mark.asyncio -@pytest.fixture -def commands(aconn, monkeypatch): - """The list of commands issued internally by the test connection.""" - yield patch_exec(aconn, monkeypatch) - - async def test_basic(aconn): """Basic use of transaction() to BEGIN and COMMIT a transaction.""" assert not in_transaction(aconn) @@ -310,7 +304,7 @@ async def test_named_savepoint_escapes_savepoint_name(aconn): pass -async def test_named_savepoints_successful_exit(aconn, commands): +async def test_named_savepoints_successful_exit(aconn, acommands): """ Entering a transaction context will do one of these these things: 1. Begin an outer transaction (if one isn't already in progress) @@ -320,6 +314,8 @@ async def test_named_savepoints_successful_exit(aconn, commands): ...and exiting the context successfully will "commit" the same. """ + commands = acommands + # Case 1 # Using Transaction explicitly becase conn.transaction() enters the contetx async with aconn.transaction() as tx: @@ -363,12 +359,14 @@ async def test_named_savepoints_successful_exit(aconn, commands): assert commands.popall() == ["commit"] -async def test_named_savepoints_exception_exit(aconn, commands): +async def test_named_savepoints_exception_exit(aconn, acommands): """ Same as the previous test but checks that when exiting the context with an exception, whatever transaction and/or savepoint was started on enter will be rolled-back as appropriate. """ + commands = acommands + # Case 1 with pytest.raises(ExpectedException): async with aconn.transaction() as tx: