import threading
from types import TracebackType
from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple
-from typing import Optional, Type, TYPE_CHECKING
+from typing import Optional, overload, Type, Union, TYPE_CHECKING
from weakref import ref, ReferenceType
from functools import partial
from contextlib import contextmanager
from . import pq
from . import adapt
-from . import cursor
from . import errors as e
from . import waiting
from . import encodings
from .sql import Composable
from .proto import PQGen, PQGenConn, RV, Query, Params, AdaptContext
from .proto import ConnectionType
+from .cursor import Cursor, AsyncCursor
from .conninfo import make_conninfo
from .generators import notifies
from .transaction import Transaction, AsyncTransaction
+from .named_cursor import NamedCursor, AsyncNamedCursor
from ._preparing import PrepareManager
logger = logging.getLogger(__name__)
execute: Callable[["PGconn"], PQGen[List["PGresult"]]]
if TYPE_CHECKING:
- from .cursor import AsyncCursor, BaseCursor, Cursor
from .pq.proto import PGconn, PGresult
if pq.__impl__ == "c":
ConnStatus = pq.ConnStatus
TransactionStatus = pq.TransactionStatus
- cursor_factory: Type["BaseCursor[Any]"]
-
def __init__(self, pgconn: "PGconn"):
self.pgconn = pgconn # TODO: document this
self._autocommit = False
__module__ = "psycopg3"
- cursor_factory: Type["Cursor"]
-
def __init__(self, pgconn: "PGconn"):
super().__init__(pgconn)
self.lock = threading.Lock()
- self.cursor_factory = cursor.Cursor
@classmethod
def connect(
"""Close the database connection."""
self.pgconn.finish()
- def cursor(self, name: str = "", binary: bool = False) -> "Cursor":
+ @overload
+ def cursor(self, *, binary: bool = False) -> Cursor:
+ ...
+
+ @overload
+ def cursor(self, name: str, *, binary: bool = False) -> NamedCursor:
+ ...
+
+ def cursor(
+ self, name: str = "", *, binary: bool = False
+ ) -> Union[Cursor, NamedCursor]:
"""
Return a new `Cursor` to send commands and queries to the connection.
"""
- if name:
- raise NotImplementedError
-
format = Format.BINARY if binary else Format.TEXT
- return self.cursor_factory(self, format=format)
+ if name:
+ return NamedCursor(self, name=name, format=format)
+ else:
+ return Cursor(self, format=format)
def execute(
self,
query: Query,
params: Optional[Params] = None,
prepare: Optional[bool] = None,
- ) -> "Cursor":
+ ) -> Cursor:
"""Execute a query and return a cursor to read its results."""
cur = self.cursor()
return cur.execute(query, params, prepare=prepare)
__module__ = "psycopg3"
- cursor_factory: Type["AsyncCursor"]
-
def __init__(self, pgconn: "PGconn"):
super().__init__(pgconn)
self.lock = asyncio.Lock()
- self.cursor_factory = cursor.AsyncCursor
@classmethod
async def connect(
async def close(self) -> None:
self.pgconn.finish()
+ @overload
+ async def cursor(self, *, binary: bool = False) -> AsyncCursor:
+ ...
+
+ @overload
+ async def cursor(
+ self, name: str, *, binary: bool = False
+ ) -> AsyncNamedCursor:
+ ...
+
async def cursor(
- self, name: str = "", binary: bool = False
- ) -> "AsyncCursor":
+ self, name: str = "", *, binary: bool = False
+ ) -> Union[AsyncCursor, AsyncNamedCursor]:
"""
Return a new `AsyncCursor` to send commands and queries to the connection.
"""
- if name:
- raise NotImplementedError
-
format = Format.BINARY if binary else Format.TEXT
- return self.cursor_factory(self, format=format)
+ if name:
+ return AsyncNamedCursor(self, name=name, format=format)
+ else:
+ return AsyncCursor(self, format=format)
async def execute(
self,
query: Query,
params: Optional[Params] = None,
prepare: Optional[bool] = None,
- ) -> "AsyncCursor":
+ ) -> AsyncCursor:
cur = await self.cursor()
return await cur.execute(query, params, prepare=prepare)
if TYPE_CHECKING:
from .proto import Transformer
from .pq.proto import PGconn, PGresult
+ from .connection import BaseConnection # noqa: F401
from .connection import Connection, AsyncConnection # noqa: F401
execute: Callable[["PGconn"], PQGen[List["PGresult"]]]
_tx: "Transformer"
def __init__(
- self,
- connection: ConnectionType,
- format: Format = Format.TEXT,
+ self, connection: ConnectionType, *, format: Format = Format.TEXT
):
self._conn = connection
self.format = format
`!None` if the current resultset didn't return tuples.
"""
res = self.pgresult
- if not res or res.status != ExecStatus.TUPLES_OK:
+ if not (res and res.nfields):
return None
return [Column(self, i) for i in range(res.nfields)]
self,
query: Query,
params: Optional[Params] = None,
+ *,
prepare: Optional[bool] = None,
) -> PQGen[None]:
"""Generator implementing `Cursor.execute()`."""
yield from self._start_query(query)
pgq = self._convert_query(query, params)
- yield from self._maybe_prepare_gen(pgq, prepare)
+ results = yield from self._maybe_prepare_gen(pgq, prepare)
+ self._execute_results(results)
self._last_query = query
def _executemany_gen(
else:
pgq.dump(params)
- yield from self._maybe_prepare_gen(pgq, True)
+ results = yield from self._maybe_prepare_gen(pgq, True)
+ self._execute_results(results)
self._last_query = query
def _maybe_prepare_gen(
self, pgq: PostgresQuery, prepare: Optional[bool]
- ) -> PQGen[None]:
+ ) -> PQGen[Sequence["PGresult"]]:
# Check if the query is prepared or needs preparing
prep, name = self._conn._prepared.get(pgq, prepare)
if prep is Prepare.YES:
if cmd:
yield from self._conn._exec_command(cmd)
- self._execute_results(results)
+ return results
def _stream_send_gen(
self, query: Query, params: Optional[Params] = None
f" FROM STDIN statements, got {ExecStatus(status).name}"
)
+ def _close(self) -> None:
+ self._closed = True
+ # however keep the query available, which can be useful for debugging
+ # in case of errors
+ pgq = self._pgq
+ self._reset()
+ self._pgq = pgq
+
class Cursor(BaseCursor["Connection"]):
__module__ = "psycopg3"
"""
Close the current cursor and free associated resources.
"""
- self._closed = True
- # however keep the query available, which can be useful for debugging
- # in case of errors
- pgq = self._pgq
- self._reset()
- self._pgq = pgq
+ self._close()
def execute(
self,
query: Query,
params: Optional[Params] = None,
+ *,
prepare: Optional[bool] = None,
) -> "Cursor":
"""
await self.close()
async def close(self) -> None:
- self._closed = True
- self._reset()
+ self._close()
async def execute(
self,
query: Query,
params: Optional[Params] = None,
+ *,
prepare: Optional[bool] = None,
) -> "AsyncCursor":
async with self._conn.lock:
async with AsyncCopy(self) as copy:
yield copy
-
-
-class NamedCursorMixin:
- pass
-
-
-class NamedCursor(NamedCursorMixin, Cursor):
- pass
-
-
-class AsyncNamedCursor(NamedCursorMixin, AsyncCursor):
- pass
--- /dev/null
+"""
+psycopg3 named cursor objects (server-side cursors)
+"""
+
+# Copyright (C) 2020-2021 The Psycopg Team
+
+import weakref
+import warnings
+from types import TracebackType
+from typing import Any, Generic, Optional, Type, TYPE_CHECKING
+
+from . import sql
+from .pq import Format
+from .cursor import BaseCursor, execute
+from .proto import ConnectionType, Query, Params, PQGen
+
+if TYPE_CHECKING:
+ from .connection import BaseConnection # noqa: F401
+ from .connection import Connection, AsyncConnection # noqa: F401
+
+
+class NamedCursorHelper(Generic[ConnectionType]):
+ __slots__ = ("name", "_wcur")
+
+ def __init__(
+ self,
+ name: str,
+ cursor: BaseCursor[ConnectionType],
+ ):
+ self.name = name
+ self._wcur = weakref.ref(cursor)
+
+ @property
+ def _cur(self) -> BaseCursor[Any]:
+ cur = self._wcur()
+ assert cur
+ return cur
+
+ def _declare_gen(
+ self, query: Query, params: Optional[Params] = None
+ ) -> PQGen[None]:
+ """Generator implementing `NamedCursor.execute()`."""
+ cur = self._cur
+ yield from cur._start_query(query)
+ pgq = cur._convert_query(query, params)
+ cur._execute_send(pgq)
+ results = yield from execute(cur._conn.pgconn)
+ cur._execute_results(results)
+
+ # The above result is an COMMAND_OK. Get the cursor result shape
+ cur._conn.pgconn.send_describe_portal(
+ self.name.encode(cur._conn.client_encoding)
+ )
+ results = yield from execute(cur._conn.pgconn)
+ cur._execute_results(results)
+
+ def _make_declare_statement(
+ self, query: Query, scrollable: bool, hold: bool
+ ) -> sql.Composable:
+ cur = self._cur
+ if isinstance(query, bytes):
+ query = query.decode(cur._conn.client_encoding)
+ if not isinstance(query, sql.Composable):
+ query = sql.SQL(query)
+
+ return sql.SQL(
+ "declare {name} {scroll} cursor{hold} for {query}"
+ ).format(
+ name=sql.Identifier(self.name),
+ scroll=sql.SQL("scroll" if scrollable else "no scroll"),
+ hold=sql.SQL(" with hold" if hold else ""),
+ query=query,
+ )
+
+
+class NamedCursor(BaseCursor["Connection"]):
+ __module__ = "psycopg3"
+ __slots__ = ("_helper",)
+
+ def __init__(
+ self,
+ connection: "Connection",
+ name: str,
+ *,
+ format: Format = Format.TEXT,
+ ):
+ super().__init__(connection, format=format)
+ self._helper = NamedCursorHelper(name, self)
+
+ def __del__(self) -> None:
+ if not self._closed:
+ warnings.warn(
+ f"named cursor {self} was deleted while still open."
+ f" Please use 'with' or '.close()' to close the cursor properly",
+ ResourceWarning,
+ )
+
+ def __enter__(self) -> "NamedCursor":
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ self.close()
+
+ @property
+ def name(self) -> str:
+ return self._helper.name
+
+ def close(self) -> None:
+ """
+ Close the current cursor and free associated resources.
+ """
+ # TODO close the cursor for real
+ self._close()
+
+ def execute(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ scrollable: bool = True,
+ hold: bool = False,
+ ) -> "NamedCursor":
+ """
+ Execute a query or command to the database.
+ """
+ query = self._helper._make_declare_statement(
+ query, scrollable=scrollable, hold=hold
+ )
+ with self._conn.lock:
+ self._conn.wait(self._helper._declare_gen(query, params))
+ return self
+
+
+class AsyncNamedCursor(BaseCursor["AsyncConnection"]):
+ __module__ = "psycopg3"
+ __slots__ = ("_helper",)
+
+ def __init__(
+ self,
+ connection: "AsyncConnection",
+ name: str,
+ *,
+ format: Format = Format.TEXT,
+ ):
+ super().__init__(connection, format=format)
+ self._helper = NamedCursorHelper(name, self)
+
+ def __del__(self) -> None:
+ if not self._closed:
+ warnings.warn(
+ f"named cursor {self} was deleted while still open."
+ f" Please use 'with' or '.close()' to close the cursor properly",
+ ResourceWarning,
+ )
+
+ async def __aenter__(self) -> "AsyncNamedCursor":
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ await self.close()
+
+ @property
+ def name(self) -> str:
+ return self._helper.name
+
+ async def close(self) -> None:
+ """
+ Close the current cursor and free associated resources.
+ """
+ # TODO close the cursor for real
+ self._close()
+
+ async def execute(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ scrollable: bool = True,
+ hold: bool = False,
+ ) -> "AsyncNamedCursor":
+ """
+ Execute a query or command to the database.
+ """
+ query = self._helper._make_declare_statement(
+ query, scrollable=scrollable, hold=hold
+ )
+ async with self._conn.lock:
+ await self._conn.wait(self._helper._declare_gen(query, params))
+ return self
def describe_prepared(self, name: bytes) -> "PGresult":
...
+ def send_describe_prepared(self, name: bytes) -> None:
+ ...
+
def describe_portal(self, name: bytes) -> "PGresult":
...
+ def send_describe_portal(self, name: bytes) -> None:
+ ...
+
def get_result(self) -> Optional["PGresult"]:
...
--- /dev/null
+def test_description(conn):
+ cur = conn.cursor("foo")
+ assert cur.name == "foo"
+ cur.execute("select generate_series(1, 10) as bar")
+ assert len(cur.description) == 1
+ assert cur.description[0].name == "bar"
+ assert cur.description[0].type_code == cur.adapters.types["int4"].oid
+ assert cur.pgresult.ntuples == 0
--- /dev/null
+import pytest
+
+pytestmark = pytest.mark.asyncio
+
+
+async def test_description(aconn):
+ cur = await aconn.cursor("foo")
+ assert cur.name == "foo"
+ await cur.execute("select generate_series(1, 10) as bar")
+ assert len(cur.description) == 1
+ assert cur.description[0].name == "bar"
+ assert cur.description[0].type_code == cur.adapters.types["int4"].oid
+ assert cur.pgresult.ntuples == 0