From: Daniele Varrazzo Date: Tue, 9 Feb 2021 03:22:01 +0000 (+0100) Subject: Added classes for named cursors X-Git-Tag: 3.0.dev0~115^2~20 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=decae5e4346dbf7c196f5bdb5433c1f4310adbfe;p=thirdparty%2Fpsycopg.git Added classes for named cursors Only execute implemented, with a describe roundtrip to get the portal description. --- diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index 0727d415c..7d097a601 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -11,7 +11,7 @@ import warnings import threading from types import TracebackType from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple -from typing import Optional, Type, TYPE_CHECKING +from typing import Optional, overload, Type, Union, TYPE_CHECKING from weakref import ref, ReferenceType from functools import partial from contextlib import contextmanager @@ -23,7 +23,6 @@ else: from . import pq from . import adapt -from . import cursor from . import errors as e from . import waiting from . import encodings @@ -31,9 +30,11 @@ from .pq import ConnStatus, ExecStatus, TransactionStatus, Format from .sql import Composable from .proto import PQGen, PQGenConn, RV, Query, Params, AdaptContext from .proto import ConnectionType +from .cursor import Cursor, AsyncCursor from .conninfo import make_conninfo from .generators import notifies from .transaction import Transaction, AsyncTransaction +from .named_cursor import NamedCursor, AsyncNamedCursor from ._preparing import PrepareManager logger = logging.getLogger(__name__) @@ -43,7 +44,6 @@ connect: Callable[[str], PQGenConn["PGconn"]] execute: Callable[["PGconn"], PQGen[List["PGresult"]]] if TYPE_CHECKING: - from .cursor import AsyncCursor, BaseCursor, Cursor from .pq.proto import PGconn, PGresult if pq.__impl__ == "c": @@ -102,8 +102,6 @@ class BaseConnection(AdaptContext): ConnStatus = pq.ConnStatus TransactionStatus = pq.TransactionStatus - cursor_factory: Type["BaseCursor[Any]"] - def __init__(self, pgconn: "PGconn"): self.pgconn = pgconn # TODO: document this self._autocommit = False @@ -400,12 +398,9 @@ class Connection(BaseConnection): __module__ = "psycopg3" - cursor_factory: Type["Cursor"] - def __init__(self, pgconn: "PGconn"): super().__init__(pgconn) self.lock = threading.Lock() - self.cursor_factory = cursor.Cursor @classmethod def connect( @@ -448,22 +443,32 @@ class Connection(BaseConnection): """Close the database connection.""" self.pgconn.finish() - def cursor(self, name: str = "", binary: bool = False) -> "Cursor": + @overload + def cursor(self, *, binary: bool = False) -> Cursor: + ... + + @overload + def cursor(self, name: str, *, binary: bool = False) -> NamedCursor: + ... + + def cursor( + self, name: str = "", *, binary: bool = False + ) -> Union[Cursor, NamedCursor]: """ Return a new `Cursor` to send commands and queries to the connection. """ - if name: - raise NotImplementedError - format = Format.BINARY if binary else Format.TEXT - return self.cursor_factory(self, format=format) + if name: + return NamedCursor(self, name=name, format=format) + else: + return Cursor(self, format=format) def execute( self, query: Query, params: Optional[Params] = None, prepare: Optional[bool] = None, - ) -> "Cursor": + ) -> Cursor: """Execute a query and return a cursor to read its results.""" cur = self.cursor() return cur.execute(query, params, prepare=prepare) @@ -541,12 +546,9 @@ class AsyncConnection(BaseConnection): __module__ = "psycopg3" - cursor_factory: Type["AsyncCursor"] - def __init__(self, pgconn: "PGconn"): super().__init__(pgconn) self.lock = asyncio.Lock() - self.cursor_factory = cursor.AsyncCursor @classmethod async def connect( @@ -583,24 +585,34 @@ class AsyncConnection(BaseConnection): async def close(self) -> None: self.pgconn.finish() + @overload + async def cursor(self, *, binary: bool = False) -> AsyncCursor: + ... + + @overload + async def cursor( + self, name: str, *, binary: bool = False + ) -> AsyncNamedCursor: + ... + async def cursor( - self, name: str = "", binary: bool = False - ) -> "AsyncCursor": + self, name: str = "", *, binary: bool = False + ) -> Union[AsyncCursor, AsyncNamedCursor]: """ Return a new `AsyncCursor` to send commands and queries to the connection. """ - if name: - raise NotImplementedError - format = Format.BINARY if binary else Format.TEXT - return self.cursor_factory(self, format=format) + if name: + return AsyncNamedCursor(self, name=name, format=format) + else: + return AsyncCursor(self, format=format) async def execute( self, query: Query, params: Optional[Params] = None, prepare: Optional[bool] = None, - ) -> "AsyncCursor": + ) -> AsyncCursor: cur = await self.cursor() return await cur.execute(query, params, prepare=prepare) diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index abc5d2cd4..b6028195d 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -30,6 +30,7 @@ else: if TYPE_CHECKING: from .proto import Transformer from .pq.proto import PGconn, PGresult + from .connection import BaseConnection # noqa: F401 from .connection import Connection, AsyncConnection # noqa: F401 execute: Callable[["PGconn"], PQGen[List["PGresult"]]] @@ -58,9 +59,7 @@ class BaseCursor(Generic[ConnectionType]): _tx: "Transformer" def __init__( - self, - connection: ConnectionType, - format: Format = Format.TEXT, + self, connection: ConnectionType, *, format: Format = Format.TEXT ): self._conn = connection self.format = format @@ -138,7 +137,7 @@ class BaseCursor(Generic[ConnectionType]): `!None` if the current resultset didn't return tuples. """ res = self.pgresult - if not res or res.status != ExecStatus.TUPLES_OK: + if not (res and res.nfields): return None return [Column(self, i) for i in range(res.nfields)] @@ -184,12 +183,14 @@ class BaseCursor(Generic[ConnectionType]): self, query: Query, params: Optional[Params] = None, + *, prepare: Optional[bool] = None, ) -> PQGen[None]: """Generator implementing `Cursor.execute()`.""" yield from self._start_query(query) pgq = self._convert_query(query, params) - yield from self._maybe_prepare_gen(pgq, prepare) + results = yield from self._maybe_prepare_gen(pgq, prepare) + self._execute_results(results) self._last_query = query def _executemany_gen( @@ -206,13 +207,14 @@ class BaseCursor(Generic[ConnectionType]): else: pgq.dump(params) - yield from self._maybe_prepare_gen(pgq, True) + results = yield from self._maybe_prepare_gen(pgq, True) + self._execute_results(results) self._last_query = query def _maybe_prepare_gen( self, pgq: PostgresQuery, prepare: Optional[bool] - ) -> PQGen[None]: + ) -> PQGen[Sequence["PGresult"]]: # Check if the query is prepared or needs preparing prep, name = self._conn._prepared.get(pgq, prepare) if prep is Prepare.YES: @@ -242,7 +244,7 @@ class BaseCursor(Generic[ConnectionType]): if cmd: yield from self._conn._exec_command(cmd) - self._execute_results(results) + return results def _stream_send_gen( self, query: Query, params: Optional[Params] = None @@ -429,6 +431,14 @@ class BaseCursor(Generic[ConnectionType]): f" FROM STDIN statements, got {ExecStatus(status).name}" ) + def _close(self) -> None: + self._closed = True + # however keep the query available, which can be useful for debugging + # in case of errors + pgq = self._pgq + self._reset() + self._pgq = pgq + class Cursor(BaseCursor["Connection"]): __module__ = "psycopg3" @@ -449,17 +459,13 @@ class Cursor(BaseCursor["Connection"]): """ Close the current cursor and free associated resources. """ - self._closed = True - # however keep the query available, which can be useful for debugging - # in case of errors - pgq = self._pgq - self._reset() - self._pgq = pgq + self._close() def execute( self, query: Query, params: Optional[Params] = None, + *, prepare: Optional[bool] = None, ) -> "Cursor": """ @@ -568,13 +574,13 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): await self.close() async def close(self) -> None: - self._closed = True - self._reset() + self._close() async def execute( self, query: Query, params: Optional[Params] = None, + *, prepare: Optional[bool] = None, ) -> "AsyncCursor": async with self._conn.lock: @@ -644,15 +650,3 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): async with AsyncCopy(self) as copy: yield copy - - -class NamedCursorMixin: - pass - - -class NamedCursor(NamedCursorMixin, Cursor): - pass - - -class AsyncNamedCursor(NamedCursorMixin, AsyncCursor): - pass diff --git a/psycopg3/psycopg3/named_cursor.py b/psycopg3/psycopg3/named_cursor.py new file mode 100644 index 000000000..d0344b183 --- /dev/null +++ b/psycopg3/psycopg3/named_cursor.py @@ -0,0 +1,199 @@ +""" +psycopg3 named cursor objects (server-side cursors) +""" + +# Copyright (C) 2020-2021 The Psycopg Team + +import weakref +import warnings +from types import TracebackType +from typing import Any, Generic, Optional, Type, TYPE_CHECKING + +from . import sql +from .pq import Format +from .cursor import BaseCursor, execute +from .proto import ConnectionType, Query, Params, PQGen + +if TYPE_CHECKING: + from .connection import BaseConnection # noqa: F401 + from .connection import Connection, AsyncConnection # noqa: F401 + + +class NamedCursorHelper(Generic[ConnectionType]): + __slots__ = ("name", "_wcur") + + def __init__( + self, + name: str, + cursor: BaseCursor[ConnectionType], + ): + self.name = name + self._wcur = weakref.ref(cursor) + + @property + def _cur(self) -> BaseCursor[Any]: + cur = self._wcur() + assert cur + return cur + + def _declare_gen( + self, query: Query, params: Optional[Params] = None + ) -> PQGen[None]: + """Generator implementing `NamedCursor.execute()`.""" + cur = self._cur + yield from cur._start_query(query) + pgq = cur._convert_query(query, params) + cur._execute_send(pgq) + results = yield from execute(cur._conn.pgconn) + cur._execute_results(results) + + # The above result is an COMMAND_OK. Get the cursor result shape + cur._conn.pgconn.send_describe_portal( + self.name.encode(cur._conn.client_encoding) + ) + results = yield from execute(cur._conn.pgconn) + cur._execute_results(results) + + def _make_declare_statement( + self, query: Query, scrollable: bool, hold: bool + ) -> sql.Composable: + cur = self._cur + if isinstance(query, bytes): + query = query.decode(cur._conn.client_encoding) + if not isinstance(query, sql.Composable): + query = sql.SQL(query) + + return sql.SQL( + "declare {name} {scroll} cursor{hold} for {query}" + ).format( + name=sql.Identifier(self.name), + scroll=sql.SQL("scroll" if scrollable else "no scroll"), + hold=sql.SQL(" with hold" if hold else ""), + query=query, + ) + + +class NamedCursor(BaseCursor["Connection"]): + __module__ = "psycopg3" + __slots__ = ("_helper",) + + def __init__( + self, + connection: "Connection", + name: str, + *, + format: Format = Format.TEXT, + ): + super().__init__(connection, format=format) + self._helper = NamedCursorHelper(name, self) + + def __del__(self) -> None: + if not self._closed: + warnings.warn( + f"named cursor {self} was deleted while still open." + f" Please use 'with' or '.close()' to close the cursor properly", + ResourceWarning, + ) + + def __enter__(self) -> "NamedCursor": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.close() + + @property + def name(self) -> str: + return self._helper.name + + def close(self) -> None: + """ + Close the current cursor and free associated resources. + """ + # TODO close the cursor for real + self._close() + + def execute( + self, + query: Query, + params: Optional[Params] = None, + *, + scrollable: bool = True, + hold: bool = False, + ) -> "NamedCursor": + """ + Execute a query or command to the database. + """ + query = self._helper._make_declare_statement( + query, scrollable=scrollable, hold=hold + ) + with self._conn.lock: + self._conn.wait(self._helper._declare_gen(query, params)) + return self + + +class AsyncNamedCursor(BaseCursor["AsyncConnection"]): + __module__ = "psycopg3" + __slots__ = ("_helper",) + + def __init__( + self, + connection: "AsyncConnection", + name: str, + *, + format: Format = Format.TEXT, + ): + super().__init__(connection, format=format) + self._helper = NamedCursorHelper(name, self) + + def __del__(self) -> None: + if not self._closed: + warnings.warn( + f"named cursor {self} was deleted while still open." + f" Please use 'with' or '.close()' to close the cursor properly", + ResourceWarning, + ) + + async def __aenter__(self) -> "AsyncNamedCursor": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.close() + + @property + def name(self) -> str: + return self._helper.name + + async def close(self) -> None: + """ + Close the current cursor and free associated resources. + """ + # TODO close the cursor for real + self._close() + + async def execute( + self, + query: Query, + params: Optional[Params] = None, + *, + scrollable: bool = True, + hold: bool = False, + ) -> "AsyncNamedCursor": + """ + Execute a query or command to the database. + """ + query = self._helper._make_declare_statement( + query, scrollable=scrollable, hold=hold + ) + async with self._conn.lock: + await self._conn.wait(self._helper._declare_gen(query, params)) + return self diff --git a/psycopg3/psycopg3/pq/proto.py b/psycopg3/psycopg3/pq/proto.py index a62651959..53c85c2ae 100644 --- a/psycopg3/psycopg3/pq/proto.py +++ b/psycopg3/psycopg3/pq/proto.py @@ -191,9 +191,15 @@ class PGconn(Protocol): def describe_prepared(self, name: bytes) -> "PGresult": ... + def send_describe_prepared(self, name: bytes) -> None: + ... + def describe_portal(self, name: bytes) -> "PGresult": ... + def send_describe_portal(self, name: bytes) -> None: + ... + def get_result(self) -> Optional["PGresult"]: ... diff --git a/tests/test_named_cursor.py b/tests/test_named_cursor.py new file mode 100644 index 000000000..82ae83f6c --- /dev/null +++ b/tests/test_named_cursor.py @@ -0,0 +1,8 @@ +def test_description(conn): + cur = conn.cursor("foo") + assert cur.name == "foo" + cur.execute("select generate_series(1, 10) as bar") + assert len(cur.description) == 1 + assert cur.description[0].name == "bar" + assert cur.description[0].type_code == cur.adapters.types["int4"].oid + assert cur.pgresult.ntuples == 0 diff --git a/tests/test_named_cursor_async.py b/tests/test_named_cursor_async.py new file mode 100644 index 000000000..538be22e9 --- /dev/null +++ b/tests/test_named_cursor_async.py @@ -0,0 +1,13 @@ +import pytest + +pytestmark = pytest.mark.asyncio + + +async def test_description(aconn): + cur = await aconn.cursor("foo") + assert cur.name == "foo" + await cur.execute("select generate_series(1, 10) as bar") + assert len(cur.description) == 1 + assert cur.description[0].name == "bar" + assert cur.description[0].type_code == cur.adapters.types["int4"].oid + assert cur.pgresult.ntuples == 0