]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added classes for named cursors
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 9 Feb 2021 03:22:01 +0000 (04:22 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 9 Feb 2021 03:22:01 +0000 (04:22 +0100)
Only execute implemented, with a describe roundtrip to get the
portal description.

psycopg3/psycopg3/connection.py
psycopg3/psycopg3/cursor.py
psycopg3/psycopg3/named_cursor.py [new file with mode: 0644]
psycopg3/psycopg3/pq/proto.py
tests/test_named_cursor.py [new file with mode: 0644]
tests/test_named_cursor_async.py [new file with mode: 0644]

index 0727d415c94f298dbcd603413c61622cdde8ab7c..7d097a60135e8d0f724a39b36a180543712d1593 100644 (file)
@@ -11,7 +11,7 @@ import warnings
 import threading
 from types import TracebackType
 from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple
-from typing import Optional, Type, TYPE_CHECKING
+from typing import Optional, overload, Type, Union, TYPE_CHECKING
 from weakref import ref, ReferenceType
 from functools import partial
 from contextlib import contextmanager
@@ -23,7 +23,6 @@ else:
 
 from . import pq
 from . import adapt
-from . import cursor
 from . import errors as e
 from . import waiting
 from . import encodings
@@ -31,9 +30,11 @@ from .pq import ConnStatus, ExecStatus, TransactionStatus, Format
 from .sql import Composable
 from .proto import PQGen, PQGenConn, RV, Query, Params, AdaptContext
 from .proto import ConnectionType
+from .cursor import Cursor, AsyncCursor
 from .conninfo import make_conninfo
 from .generators import notifies
 from .transaction import Transaction, AsyncTransaction
+from .named_cursor import NamedCursor, AsyncNamedCursor
 from ._preparing import PrepareManager
 
 logger = logging.getLogger(__name__)
@@ -43,7 +44,6 @@ connect: Callable[[str], PQGenConn["PGconn"]]
 execute: Callable[["PGconn"], PQGen[List["PGresult"]]]
 
 if TYPE_CHECKING:
-    from .cursor import AsyncCursor, BaseCursor, Cursor
     from .pq.proto import PGconn, PGresult
 
 if pq.__impl__ == "c":
@@ -102,8 +102,6 @@ class BaseConnection(AdaptContext):
     ConnStatus = pq.ConnStatus
     TransactionStatus = pq.TransactionStatus
 
-    cursor_factory: Type["BaseCursor[Any]"]
-
     def __init__(self, pgconn: "PGconn"):
         self.pgconn = pgconn  # TODO: document this
         self._autocommit = False
@@ -400,12 +398,9 @@ class Connection(BaseConnection):
 
     __module__ = "psycopg3"
 
-    cursor_factory: Type["Cursor"]
-
     def __init__(self, pgconn: "PGconn"):
         super().__init__(pgconn)
         self.lock = threading.Lock()
-        self.cursor_factory = cursor.Cursor
 
     @classmethod
     def connect(
@@ -448,22 +443,32 @@ class Connection(BaseConnection):
         """Close the database connection."""
         self.pgconn.finish()
 
-    def cursor(self, name: str = "", binary: bool = False) -> "Cursor":
+    @overload
+    def cursor(self, *, binary: bool = False) -> Cursor:
+        ...
+
+    @overload
+    def cursor(self, name: str, *, binary: bool = False) -> NamedCursor:
+        ...
+
+    def cursor(
+        self, name: str = "", *, binary: bool = False
+    ) -> Union[Cursor, NamedCursor]:
         """
         Return a new `Cursor` to send commands and queries to the connection.
         """
-        if name:
-            raise NotImplementedError
-
         format = Format.BINARY if binary else Format.TEXT
-        return self.cursor_factory(self, format=format)
+        if name:
+            return NamedCursor(self, name=name, format=format)
+        else:
+            return Cursor(self, format=format)
 
     def execute(
         self,
         query: Query,
         params: Optional[Params] = None,
         prepare: Optional[bool] = None,
-    ) -> "Cursor":
+    ) -> Cursor:
         """Execute a query and return a cursor to read its results."""
         cur = self.cursor()
         return cur.execute(query, params, prepare=prepare)
@@ -541,12 +546,9 @@ class AsyncConnection(BaseConnection):
 
     __module__ = "psycopg3"
 
-    cursor_factory: Type["AsyncCursor"]
-
     def __init__(self, pgconn: "PGconn"):
         super().__init__(pgconn)
         self.lock = asyncio.Lock()
-        self.cursor_factory = cursor.AsyncCursor
 
     @classmethod
     async def connect(
@@ -583,24 +585,34 @@ class AsyncConnection(BaseConnection):
     async def close(self) -> None:
         self.pgconn.finish()
 
+    @overload
+    async def cursor(self, *, binary: bool = False) -> AsyncCursor:
+        ...
+
+    @overload
+    async def cursor(
+        self, name: str, *, binary: bool = False
+    ) -> AsyncNamedCursor:
+        ...
+
     async def cursor(
-        self, name: str = "", binary: bool = False
-    ) -> "AsyncCursor":
+        self, name: str = "", *, binary: bool = False
+    ) -> Union[AsyncCursor, AsyncNamedCursor]:
         """
         Return a new `AsyncCursor` to send commands and queries to the connection.
         """
-        if name:
-            raise NotImplementedError
-
         format = Format.BINARY if binary else Format.TEXT
-        return self.cursor_factory(self, format=format)
+        if name:
+            return AsyncNamedCursor(self, name=name, format=format)
+        else:
+            return AsyncCursor(self, format=format)
 
     async def execute(
         self,
         query: Query,
         params: Optional[Params] = None,
         prepare: Optional[bool] = None,
-    ) -> "AsyncCursor":
+    ) -> AsyncCursor:
         cur = await self.cursor()
         return await cur.execute(query, params, prepare=prepare)
 
index abc5d2cd4d51040401fb8ef26131458c0aba1d5d..b6028195da25134fe5735caf43204c482be6cce8 100644 (file)
@@ -30,6 +30,7 @@ else:
 if TYPE_CHECKING:
     from .proto import Transformer
     from .pq.proto import PGconn, PGresult
+    from .connection import BaseConnection  # noqa: F401
     from .connection import Connection, AsyncConnection  # noqa: F401
 
 execute: Callable[["PGconn"], PQGen[List["PGresult"]]]
@@ -58,9 +59,7 @@ class BaseCursor(Generic[ConnectionType]):
     _tx: "Transformer"
 
     def __init__(
-        self,
-        connection: ConnectionType,
-        format: Format = Format.TEXT,
+        self, connection: ConnectionType, *, format: Format = Format.TEXT
     ):
         self._conn = connection
         self.format = format
@@ -138,7 +137,7 @@ class BaseCursor(Generic[ConnectionType]):
         `!None` if the current resultset didn't return tuples.
         """
         res = self.pgresult
-        if not res or res.status != ExecStatus.TUPLES_OK:
+        if not (res and res.nfields):
             return None
         return [Column(self, i) for i in range(res.nfields)]
 
@@ -184,12 +183,14 @@ class BaseCursor(Generic[ConnectionType]):
         self,
         query: Query,
         params: Optional[Params] = None,
+        *,
         prepare: Optional[bool] = None,
     ) -> PQGen[None]:
         """Generator implementing `Cursor.execute()`."""
         yield from self._start_query(query)
         pgq = self._convert_query(query, params)
-        yield from self._maybe_prepare_gen(pgq, prepare)
+        results = yield from self._maybe_prepare_gen(pgq, prepare)
+        self._execute_results(results)
         self._last_query = query
 
     def _executemany_gen(
@@ -206,13 +207,14 @@ class BaseCursor(Generic[ConnectionType]):
             else:
                 pgq.dump(params)
 
-            yield from self._maybe_prepare_gen(pgq, True)
+            results = yield from self._maybe_prepare_gen(pgq, True)
+            self._execute_results(results)
 
         self._last_query = query
 
     def _maybe_prepare_gen(
         self, pgq: PostgresQuery, prepare: Optional[bool]
-    ) -> PQGen[None]:
+    ) -> PQGen[Sequence["PGresult"]]:
         # Check if the query is prepared or needs preparing
         prep, name = self._conn._prepared.get(pgq, prepare)
         if prep is Prepare.YES:
@@ -242,7 +244,7 @@ class BaseCursor(Generic[ConnectionType]):
             if cmd:
                 yield from self._conn._exec_command(cmd)
 
-        self._execute_results(results)
+        return results
 
     def _stream_send_gen(
         self, query: Query, params: Optional[Params] = None
@@ -429,6 +431,14 @@ class BaseCursor(Generic[ConnectionType]):
                 f" FROM STDIN statements, got {ExecStatus(status).name}"
             )
 
+    def _close(self) -> None:
+        self._closed = True
+        # however keep the query available, which can be useful for debugging
+        # in case of errors
+        pgq = self._pgq
+        self._reset()
+        self._pgq = pgq
+
 
 class Cursor(BaseCursor["Connection"]):
     __module__ = "psycopg3"
@@ -449,17 +459,13 @@ class Cursor(BaseCursor["Connection"]):
         """
         Close the current cursor and free associated resources.
         """
-        self._closed = True
-        # however keep the query available, which can be useful for debugging
-        # in case of errors
-        pgq = self._pgq
-        self._reset()
-        self._pgq = pgq
+        self._close()
 
     def execute(
         self,
         query: Query,
         params: Optional[Params] = None,
+        *,
         prepare: Optional[bool] = None,
     ) -> "Cursor":
         """
@@ -568,13 +574,13 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
         await self.close()
 
     async def close(self) -> None:
-        self._closed = True
-        self._reset()
+        self._close()
 
     async def execute(
         self,
         query: Query,
         params: Optional[Params] = None,
+        *,
         prepare: Optional[bool] = None,
     ) -> "AsyncCursor":
         async with self._conn.lock:
@@ -644,15 +650,3 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
 
         async with AsyncCopy(self) as copy:
             yield copy
-
-
-class NamedCursorMixin:
-    pass
-
-
-class NamedCursor(NamedCursorMixin, Cursor):
-    pass
-
-
-class AsyncNamedCursor(NamedCursorMixin, AsyncCursor):
-    pass
diff --git a/psycopg3/psycopg3/named_cursor.py b/psycopg3/psycopg3/named_cursor.py
new file mode 100644 (file)
index 0000000..d0344b1
--- /dev/null
@@ -0,0 +1,199 @@
+"""
+psycopg3 named cursor objects (server-side cursors)
+"""
+
+# Copyright (C) 2020-2021 The Psycopg Team
+
+import weakref
+import warnings
+from types import TracebackType
+from typing import Any, Generic, Optional, Type, TYPE_CHECKING
+
+from . import sql
+from .pq import Format
+from .cursor import BaseCursor, execute
+from .proto import ConnectionType, Query, Params, PQGen
+
+if TYPE_CHECKING:
+    from .connection import BaseConnection  # noqa: F401
+    from .connection import Connection, AsyncConnection  # noqa: F401
+
+
+class NamedCursorHelper(Generic[ConnectionType]):
+    __slots__ = ("name", "_wcur")
+
+    def __init__(
+        self,
+        name: str,
+        cursor: BaseCursor[ConnectionType],
+    ):
+        self.name = name
+        self._wcur = weakref.ref(cursor)
+
+    @property
+    def _cur(self) -> BaseCursor[Any]:
+        cur = self._wcur()
+        assert cur
+        return cur
+
+    def _declare_gen(
+        self, query: Query, params: Optional[Params] = None
+    ) -> PQGen[None]:
+        """Generator implementing `NamedCursor.execute()`."""
+        cur = self._cur
+        yield from cur._start_query(query)
+        pgq = cur._convert_query(query, params)
+        cur._execute_send(pgq)
+        results = yield from execute(cur._conn.pgconn)
+        cur._execute_results(results)
+
+        # The above result is an COMMAND_OK. Get the cursor result shape
+        cur._conn.pgconn.send_describe_portal(
+            self.name.encode(cur._conn.client_encoding)
+        )
+        results = yield from execute(cur._conn.pgconn)
+        cur._execute_results(results)
+
+    def _make_declare_statement(
+        self, query: Query, scrollable: bool, hold: bool
+    ) -> sql.Composable:
+        cur = self._cur
+        if isinstance(query, bytes):
+            query = query.decode(cur._conn.client_encoding)
+        if not isinstance(query, sql.Composable):
+            query = sql.SQL(query)
+
+        return sql.SQL(
+            "declare {name} {scroll} cursor{hold} for {query}"
+        ).format(
+            name=sql.Identifier(self.name),
+            scroll=sql.SQL("scroll" if scrollable else "no scroll"),
+            hold=sql.SQL(" with hold" if hold else ""),
+            query=query,
+        )
+
+
+class NamedCursor(BaseCursor["Connection"]):
+    __module__ = "psycopg3"
+    __slots__ = ("_helper",)
+
+    def __init__(
+        self,
+        connection: "Connection",
+        name: str,
+        *,
+        format: Format = Format.TEXT,
+    ):
+        super().__init__(connection, format=format)
+        self._helper = NamedCursorHelper(name, self)
+
+    def __del__(self) -> None:
+        if not self._closed:
+            warnings.warn(
+                f"named cursor {self} was deleted while still open."
+                f" Please use 'with' or '.close()' to close the cursor properly",
+                ResourceWarning,
+            )
+
+    def __enter__(self) -> "NamedCursor":
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        self.close()
+
+    @property
+    def name(self) -> str:
+        return self._helper.name
+
+    def close(self) -> None:
+        """
+        Close the current cursor and free associated resources.
+        """
+        # TODO close the cursor for real
+        self._close()
+
+    def execute(
+        self,
+        query: Query,
+        params: Optional[Params] = None,
+        *,
+        scrollable: bool = True,
+        hold: bool = False,
+    ) -> "NamedCursor":
+        """
+        Execute a query or command to the database.
+        """
+        query = self._helper._make_declare_statement(
+            query, scrollable=scrollable, hold=hold
+        )
+        with self._conn.lock:
+            self._conn.wait(self._helper._declare_gen(query, params))
+        return self
+
+
+class AsyncNamedCursor(BaseCursor["AsyncConnection"]):
+    __module__ = "psycopg3"
+    __slots__ = ("_helper",)
+
+    def __init__(
+        self,
+        connection: "AsyncConnection",
+        name: str,
+        *,
+        format: Format = Format.TEXT,
+    ):
+        super().__init__(connection, format=format)
+        self._helper = NamedCursorHelper(name, self)
+
+    def __del__(self) -> None:
+        if not self._closed:
+            warnings.warn(
+                f"named cursor {self} was deleted while still open."
+                f" Please use 'with' or '.close()' to close the cursor properly",
+                ResourceWarning,
+            )
+
+    async def __aenter__(self) -> "AsyncNamedCursor":
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        await self.close()
+
+    @property
+    def name(self) -> str:
+        return self._helper.name
+
+    async def close(self) -> None:
+        """
+        Close the current cursor and free associated resources.
+        """
+        # TODO close the cursor for real
+        self._close()
+
+    async def execute(
+        self,
+        query: Query,
+        params: Optional[Params] = None,
+        *,
+        scrollable: bool = True,
+        hold: bool = False,
+    ) -> "AsyncNamedCursor":
+        """
+        Execute a query or command to the database.
+        """
+        query = self._helper._make_declare_statement(
+            query, scrollable=scrollable, hold=hold
+        )
+        async with self._conn.lock:
+            await self._conn.wait(self._helper._declare_gen(query, params))
+        return self
index a626519592f7c429c231a3f9692aee80f8894b8c..53c85c2ae1588455a63965facac6d565b1cf6a71 100644 (file)
@@ -191,9 +191,15 @@ class PGconn(Protocol):
     def describe_prepared(self, name: bytes) -> "PGresult":
         ...
 
+    def send_describe_prepared(self, name: bytes) -> None:
+        ...
+
     def describe_portal(self, name: bytes) -> "PGresult":
         ...
 
+    def send_describe_portal(self, name: bytes) -> None:
+        ...
+
     def get_result(self) -> Optional["PGresult"]:
         ...
 
diff --git a/tests/test_named_cursor.py b/tests/test_named_cursor.py
new file mode 100644 (file)
index 0000000..82ae83f
--- /dev/null
@@ -0,0 +1,8 @@
+def test_description(conn):
+    cur = conn.cursor("foo")
+    assert cur.name == "foo"
+    cur.execute("select generate_series(1, 10) as bar")
+    assert len(cur.description) == 1
+    assert cur.description[0].name == "bar"
+    assert cur.description[0].type_code == cur.adapters.types["int4"].oid
+    assert cur.pgresult.ntuples == 0
diff --git a/tests/test_named_cursor_async.py b/tests/test_named_cursor_async.py
new file mode 100644 (file)
index 0000000..538be22
--- /dev/null
@@ -0,0 +1,13 @@
+import pytest
+
+pytestmark = pytest.mark.asyncio
+
+
+async def test_description(aconn):
+    cur = await aconn.cursor("foo")
+    assert cur.name == "foo"
+    await cur.execute("select generate_series(1, 10) as bar")
+    assert len(cur.description) == 1
+    assert cur.description[0].name == "bar"
+    assert cur.description[0].type_code == cur.adapters.types["int4"].oid
+    assert cur.pgresult.ntuples == 0