from . import encodings
from .pq import ConnStatus, ExecStatus, TransactionStatus, Format
from .sql import Composable
-from .proto import PQGen, PQGenConn, RV, Query, Params, AdaptContext
-from .proto import ConnectionType
+from .proto import PQGen, PQGenConn, RV, RowFactory, Query, Params
+from .proto import AdaptContext, ConnectionType
+ from .cursor import Cursor, AsyncCursor
from .conninfo import make_conninfo
from .generators import notifies
from .transaction import Transaction, AsyncTransaction
"""Close the database connection."""
self.pgconn.finish()
- def cursor(self, *, binary: bool = False) -> Cursor:
+ @overload
- def cursor(self, name: str, *, binary: bool = False) -> ServerCursor:
++ def cursor(
++ self, *, binary: bool = False, row_factory: Optional[RowFactory] = None
++ ) -> Cursor:
+ ...
+
+ @overload
++ def cursor(
++ self,
++ name: str,
++ *,
++ binary: bool = False,
++ row_factory: Optional[RowFactory] = None,
++ ) -> ServerCursor:
+ ...
+
def cursor(
- self, name: str = "", *, binary: bool = False
+ self,
+ name: str = "",
++ *,
+ binary: bool = False,
+ row_factory: Optional[RowFactory] = None,
- ) -> "Cursor":
+ ) -> Union[Cursor, ServerCursor]:
"""
- Return a new `Cursor` to send commands and queries to the connection.
+ Return a new cursor to send commands and queries to the connection.
"""
- if name:
- raise NotImplementedError
-
format = Format.BINARY if binary else Format.TEXT
- return self.cursor_factory(
- self, format=format, row_factory=row_factory
- )
+ if name:
- return ServerCursor(self, name=name, format=format)
++ return ServerCursor(
++ self, name=name, format=format, row_factory=row_factory
++ )
+ else:
- return Cursor(self, format=format)
++ return Cursor(self, format=format, row_factory=row_factory)
def execute(
self,
query: Query,
params: Optional[Params] = None,
+ *,
prepare: Optional[bool] = None,
- ) -> "Cursor":
+ row_factory: Optional[RowFactory] = None,
+ ) -> Cursor:
"""Execute a query and return a cursor to read its results."""
- cur = self.cursor()
+ cur = self.cursor(row_factory=row_factory)
return cur.execute(query, params, prepare=prepare)
def commit(self) -> None:
async def close(self) -> None:
self.pgconn.finish()
- async def cursor(
+ @overload
- def cursor(self, *, binary: bool = False) -> AsyncCursor:
++ def cursor(
++ self, *, binary: bool = False, row_factory: Optional[RowFactory] = None
++ ) -> AsyncCursor:
+ ...
+
+ @overload
- def cursor(self, name: str, *, binary: bool = False) -> AsyncServerCursor:
++ def cursor(
++ self,
++ name: str,
++ *,
++ binary: bool = False,
++ row_factory: Optional[RowFactory] = None,
++ ) -> AsyncServerCursor:
+ ...
+
+ def cursor(
- self, name: str = "", *, binary: bool = False
+ self,
+ name: str = "",
++ *,
+ binary: bool = False,
+ row_factory: Optional[RowFactory] = None,
- ) -> "AsyncCursor":
+ ) -> Union[AsyncCursor, AsyncServerCursor]:
"""
Return a new `AsyncCursor` to send commands and queries to the connection.
"""
- if name:
- raise NotImplementedError
-
format = Format.BINARY if binary else Format.TEXT
- return self.cursor_factory(
- self, format=format, row_factory=row_factory
- )
+ if name:
- return AsyncServerCursor(self, name=name, format=format)
++ return AsyncServerCursor(
++ self, name=name, format=format, row_factory=row_factory
++ )
+ else:
- return AsyncCursor(self, format=format)
++ return AsyncCursor(self, format=format, row_factory=row_factory)
async def execute(
self,
query: Query,
params: Optional[Params] = None,
+ *,
prepare: Optional[bool] = None,
- ) -> "AsyncCursor":
- cur = await self.cursor(row_factory=row_factory)
+ row_factory: Optional[RowFactory] = None,
- cur = self.cursor()
+ ) -> AsyncCursor:
++ cur = self.cursor(row_factory=row_factory)
return await cur.execute(query, params, prepare=prepare)
async def commit(self) -> None:
--- /dev/null
-from .proto import ConnectionType, Query, Params, PQGen
+ """
+ psycopg3 server-side cursor objects.
+ """
+
+ # Copyright (C) 2020-2021 The Psycopg Team
+
+ import warnings
+ from types import TracebackType
+ from typing import Any, AsyncIterator, Generic, List, Iterator, Optional
+ from typing import Sequence, Type, Tuple, TYPE_CHECKING
+
+ from . import pq
+ from . import sql
+ from . import errors as e
+ from .cursor import BaseCursor, execute
- super().__init__(connection, format=format)
++from .proto import ConnectionType, Query, Params, PQGen, RowFactory
+
+ if TYPE_CHECKING:
+ from .connection import BaseConnection # noqa: F401
+ from .connection import Connection, AsyncConnection # noqa: F401
+
+ DEFAULT_ITERSIZE = 100
+
+
+ class ServerCursorHelper(Generic[ConnectionType]):
+ __slots__ = ("name", "described")
+ """Helper object for common ServerCursor code.
+
+ TODO: this should be a mixin, but couldn't find a way to work it
+ correctly with the generic.
+ """
+
+ def __init__(self, name: str):
+ self.name = name
+ self.described = False
+
+ def _repr(self, cur: BaseCursor[ConnectionType]) -> str:
+ cls = f"{cur.__class__.__module__}.{cur.__class__.__qualname__}"
+ info = pq.misc.connection_summary(cur._conn.pgconn)
+ if cur._closed:
+ status = "closed"
+ elif not cur._pgresult:
+ status = "no result"
+ else:
+ status = pq.ExecStatus(cur._pgresult.status).name
+ return f"<{cls} {self.name!r} [{status}] {info} at 0x{id(cur):x}>"
+
+ def _declare_gen(
+ self,
+ cur: BaseCursor[ConnectionType],
+ query: Query,
+ params: Optional[Params] = None,
+ ) -> PQGen[None]:
+ """Generator implementing `ServerCursor.execute()`."""
+ conn = cur._conn
+
+ # If the cursor is being reused, the previous one must be closed.
+ if self.described:
+ yield from self._close_gen(cur)
+ self.described = False
+
+ yield from cur._start_query(query)
+ pgq = cur._convert_query(query, params)
+ cur._execute_send(pgq)
+ results = yield from execute(conn.pgconn)
+ cur._execute_results(results)
+
+ # The above result is an COMMAND_OK. Get the cursor result shape
+ yield from self._describe_gen(cur)
+
+ def _describe_gen(self, cur: BaseCursor[ConnectionType]) -> PQGen[None]:
+ conn = cur._conn
+ conn.pgconn.send_describe_portal(
+ self.name.encode(conn.client_encoding)
+ )
+ results = yield from execute(conn.pgconn)
+ cur._execute_results(results)
+ self.described = True
+
+ def _close_gen(self, cur: BaseCursor[ConnectionType]) -> PQGen[None]:
+ # if the connection is not in a sane state, don't even try
+ if cur._conn.pgconn.transaction_status not in (
+ pq.TransactionStatus.IDLE,
+ pq.TransactionStatus.INTRANS,
+ ):
+ return
+
+ # if we didn't declare the cursor ourselves we still have to close it
+ # but we must make sure it exists.
+ if not self.described:
+ query = sql.SQL(
+ "select 1 from pg_catalog.pg_cursors where name = {}"
+ ).format(sql.Literal(self.name))
+ res = yield from cur._conn._exec_command(query)
+ if res.ntuples == 0:
+ return
+
+ query = sql.SQL("close {}").format(sql.Identifier(self.name))
+ yield from cur._conn._exec_command(query)
+
+ def _fetch_gen(
+ self, cur: BaseCursor[ConnectionType], num: Optional[int]
+ ) -> PQGen[List[Tuple[Any, ...]]]:
+ # If we are stealing the cursor, make sure we know its shape
+ if not self.described:
+ yield from cur._start_query()
+ yield from self._describe_gen(cur)
+
+ if num is not None:
+ howmuch: sql.Composable = sql.Literal(num)
+ else:
+ howmuch = sql.SQL("all")
+
+ query = sql.SQL("fetch forward {} from {}").format(
+ howmuch, sql.Identifier(self.name)
+ )
+ res = yield from cur._conn._exec_command(query)
+
+ # TODO: loaders don't need to be refreshed
+ cur.pgresult = res
+ return cur._tx.load_rows(0, res.ntuples)
+
+ def _scroll_gen(
+ self, cur: BaseCursor[ConnectionType], value: int, mode: str
+ ) -> PQGen[None]:
+ if mode not in ("relative", "absolute"):
+ raise ValueError(
+ f"bad mode: {mode}. It should be 'relative' or 'absolute'"
+ )
+ query = sql.SQL("move{} {} from {}").format(
+ sql.SQL(" absolute" if mode == "absolute" else ""),
+ sql.Literal(value),
+ sql.Identifier(self.name),
+ )
+ yield from cur._conn._exec_command(query)
+
+ def _make_declare_statement(
+ self,
+ cur: BaseCursor[ConnectionType],
+ query: Query,
+ scrollable: Optional[bool],
+ hold: bool,
+ ) -> sql.Composable:
+
+ if isinstance(query, bytes):
+ query = query.decode(cur._conn.client_encoding)
+ if not isinstance(query, sql.Composable):
+ query = sql.SQL(query)
+
+ parts = [
+ sql.SQL("declare"),
+ sql.Identifier(self.name),
+ ]
+ if scrollable is not None:
+ parts.append(sql.SQL("scroll" if scrollable else "no scroll"))
+ parts.append(sql.SQL("cursor"))
+ if hold:
+ parts.append(sql.SQL("with hold"))
+ parts.append(sql.SQL("for"))
+ parts.append(query)
+
+ return sql.SQL(" ").join(parts)
+
+
+ class ServerCursor(BaseCursor["Connection"]):
+ __module__ = "psycopg3"
+ __slots__ = ("_helper", "itersize")
+
+ def __init__(
+ self,
+ connection: "Connection",
+ name: str,
+ *,
+ format: pq.Format = pq.Format.TEXT,
++ row_factory: Optional[RowFactory] = None,
+ ):
- super().__init__(connection, format=format)
++ super().__init__(connection, format=format, row_factory=row_factory)
+ self._helper: ServerCursorHelper["Connection"] = ServerCursorHelper(
+ name
+ )
+ self.itersize = DEFAULT_ITERSIZE
+
+ def __del__(self) -> None:
+ if not self._closed:
+ warnings.warn(
+ f"the server-side cursor {self} was deleted while still open."
+ f" Please use 'with' or '.close()' to close the cursor properly",
+ ResourceWarning,
+ )
+
+ def __repr__(self) -> str:
+ return self._helper._repr(self)
+
+ def __enter__(self) -> "ServerCursor":
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ self.close()
+
+ @property
+ def name(self) -> str:
+ """The name of the cursor."""
+ return self._helper.name
+
+ def close(self) -> None:
+ """
+ Close the current cursor and free associated resources.
+ """
+ with self._conn.lock:
+ self._conn.wait(self._helper._close_gen(self))
+ self._close()
+
+ def execute(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ scrollable: Optional[bool] = None,
+ hold: bool = False,
+ ) -> "ServerCursor":
+ """
+ Open a cursor to execute a query to the database.
+ """
+ query = self._helper._make_declare_statement(
+ self, query, scrollable=scrollable, hold=hold
+ )
+ with self._conn.lock:
+ self._conn.wait(self._helper._declare_gen(self, query, params))
+ return self
+
+ def executemany(self, query: Query, params_seq: Sequence[Params]) -> None:
+ """Method not implemented for server-side cursors."""
+ raise e.NotSupportedError(
+ "executemany not supported on server-side cursors"
+ )
+
+ def fetchone(self) -> Optional[Sequence[Any]]:
+ with self._conn.lock:
+ recs = self._conn.wait(self._helper._fetch_gen(self, 1))
+ if recs:
+ self._pos += 1
+ return recs[0]
+ else:
+ return None
+
+ def fetchmany(self, size: int = 0) -> Sequence[Sequence[Any]]:
+ if not size:
+ size = self.arraysize
+ with self._conn.lock:
+ recs = self._conn.wait(self._helper._fetch_gen(self, size))
+ self._pos += len(recs)
+ return recs
+
+ def fetchall(self) -> Sequence[Sequence[Any]]:
+ with self._conn.lock:
+ recs = self._conn.wait(self._helper._fetch_gen(self, None))
+ self._pos += len(recs)
+ return recs
+
+ def __iter__(self) -> Iterator[Sequence[Any]]:
+ while True:
+ with self._conn.lock:
+ recs = self._conn.wait(
+ self._helper._fetch_gen(self, self.itersize)
+ )
+ for rec in recs:
+ self._pos += 1
+ yield rec
+ if len(recs) < self.itersize:
+ break
+
+ def scroll(self, value: int, mode: str = "relative") -> None:
+ with self._conn.lock:
+ self._conn.wait(self._helper._scroll_gen(self, value, mode))
+ # Postgres doesn't have a reliable way to report a cursor out of bound
+ if mode == "relative":
+ self._pos += value
+ else:
+ self._pos = value
+
+
+ class AsyncServerCursor(BaseCursor["AsyncConnection"]):
+ __module__ = "psycopg3"
+ __slots__ = ("_helper", "itersize")
+
+ def __init__(
+ self,
+ connection: "AsyncConnection",
+ name: str,
+ *,
+ format: pq.Format = pq.Format.TEXT,
++ row_factory: Optional[RowFactory] = None,
+ ):
++ super().__init__(connection, format=format, row_factory=row_factory)
+ self._helper: ServerCursorHelper["AsyncConnection"]
+ self._helper = ServerCursorHelper(name)
+ self.itersize = DEFAULT_ITERSIZE
+
+ def __del__(self) -> None:
+ if not self._closed:
+ warnings.warn(
+ f"the server-side cursor {self} was deleted while still open."
+ f" Please use 'with' or '.close()' to close the cursor properly",
+ ResourceWarning,
+ )
+
+ def __repr__(self) -> str:
+ return self._helper._repr(self)
+
+ async def __aenter__(self) -> "AsyncServerCursor":
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ await self.close()
+
+ @property
+ def name(self) -> str:
+ return self._helper.name
+
+ async def close(self) -> None:
+ async with self._conn.lock:
+ await self._conn.wait(self._helper._close_gen(self))
+ self._close()
+
+ async def execute(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ scrollable: Optional[bool] = None,
+ hold: bool = False,
+ ) -> "AsyncServerCursor":
+ query = self._helper._make_declare_statement(
+ self, query, scrollable=scrollable, hold=hold
+ )
+ async with self._conn.lock:
+ await self._conn.wait(
+ self._helper._declare_gen(self, query, params)
+ )
+ return self
+
+ async def executemany(
+ self, query: Query, params_seq: Sequence[Params]
+ ) -> None:
+ raise e.NotSupportedError(
+ "executemany not supported on server-side cursors"
+ )
+
+ async def fetchone(self) -> Optional[Sequence[Any]]:
+ async with self._conn.lock:
+ recs = await self._conn.wait(self._helper._fetch_gen(self, 1))
+ if recs:
+ self._pos += 1
+ return recs[0]
+ else:
+ return None
+
+ async def fetchmany(self, size: int = 0) -> Sequence[Sequence[Any]]:
+ if not size:
+ size = self.arraysize
+ async with self._conn.lock:
+ recs = await self._conn.wait(self._helper._fetch_gen(self, size))
+ self._pos += len(recs)
+ return recs
+
+ async def fetchall(self) -> Sequence[Sequence[Any]]:
+ async with self._conn.lock:
+ recs = await self._conn.wait(self._helper._fetch_gen(self, None))
+ self._pos += len(recs)
+ return recs
+
+ async def __aiter__(self) -> AsyncIterator[Sequence[Any]]:
+ while True:
+ async with self._conn.lock:
+ recs = await self._conn.wait(
+ self._helper._fetch_gen(self, self.itersize)
+ )
+ for rec in recs:
+ self._pos += 1
+ yield rec
+ if len(recs) < self.itersize:
+ break
+
+ async def scroll(self, value: int, mode: str = "relative") -> None:
+ async with self._conn.lock:
+ await self._conn.wait(self._helper._scroll_gen(self, value, mode))
assert False
- cur = await aconn.cursor(row_factory=my_row_factory)
+async def test_row_factory(aconn):
+ def my_row_factory(cursor):
+ def mkrow(values):
+ assert cursor.description is not None
+ titles = [c.name for c in cursor.description]
+ return [
+ f"{value.upper()}{title}"
+ for title, value in zip(titles, values)
+ ]
+
+ return mkrow
+
++ cur = aconn.cursor(row_factory=my_row_factory)
+ await cur.execute("select 'foo' as bar")
+ (r,) = await cur.fetchone()
+ assert r == "FOObar"
+
+ await cur.execute("select 'x' as x; select 'y' as y, 'z' as z")
+ assert await cur.fetchall() == [["Xx"]]
+ assert cur.nextset()
+ assert await cur.fetchall() == [["Yy", "Zz"]]
+ assert cur.nextset() is None
+
+
+ async def test_scroll(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(psycopg3.ProgrammingError):
+ await cur.scroll(0)
+
+ await cur.execute("select generate_series(0,9)")
+ await cur.scroll(2)
+ assert await cur.fetchone() == (2,)
+ await cur.scroll(2)
+ assert await cur.fetchone() == (5,)
+ await cur.scroll(2, mode="relative")
+ assert await cur.fetchone() == (8,)
+ await cur.scroll(-1)
+ assert await cur.fetchone() == (8,)
+ await cur.scroll(-2)
+ assert await cur.fetchone() == (7,)
+ await cur.scroll(2, mode="absolute")
+ assert await cur.fetchone() == (2,)
+
+ # on the boundary
+ await cur.scroll(0, mode="absolute")
+ assert await cur.fetchone() == (0,)
+ with pytest.raises(IndexError):
+ await cur.scroll(-1, mode="absolute")
+
+ await cur.scroll(0, mode="absolute")
+ with pytest.raises(IndexError):
+ await cur.scroll(-1)
+
+ await cur.scroll(9, mode="absolute")
+ assert await cur.fetchone() == (9,)
+ with pytest.raises(IndexError):
+ await cur.scroll(10, mode="absolute")
+
+ await cur.scroll(9, mode="absolute")
+ with pytest.raises(IndexError):
+ await cur.scroll(1)
+
+ with pytest.raises(ValueError):
+ await cur.scroll(1, "wat")
+
+
async def test_query_params_execute(aconn):
- cur = await aconn.cursor()
+ cur = aconn.cursor()
assert cur.query is None
assert cur.params is None