From: Daniele Varrazzo Date: Thu, 11 Feb 2021 20:54:11 +0000 (+0100) Subject: Merge branch 'master' into row-factory X-Git-Tag: 3.0.dev0~106^2~21 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=28616a3512444e6076e74862eef27ebff9c71e92;p=thirdparty%2Fpsycopg.git Merge branch 'master' into row-factory --- 28616a3512444e6076e74862eef27ebff9c71e92 diff --cc psycopg3/psycopg3/_typeinfo.py index 7e125d92b,2d6fa2453..286d6c753 --- a/psycopg3/psycopg3/_typeinfo.py +++ b/psycopg3/psycopg3/_typeinfo.py @@@ -88,9 -88,9 +88,10 @@@ class TypeInfo if isinstance(name, Composable): name = name.as_string(conn) - cur = await conn.cursor(binary=True, row_factory=None) - cur = conn.cursor(binary=True) ++ ++ cur = conn.cursor(binary=True, row_factory=None) await cur.execute(cls._info_query, {"name": name}) - recs = await cur.fetchall() + recs: Sequence[Sequence[Any]] = await cur.fetchall() fields = [d[0] for d in cur.description or ()] return cls._fetch(name, fields, recs) diff --cc psycopg3/psycopg3/connection.py index 8f0c93281,b883071f7..fae588238 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@@ -29,8 -28,9 +28,9 @@@ from . import waitin from . import encodings from .pq import ConnStatus, ExecStatus, TransactionStatus, Format from .sql import Composable -from .proto import PQGen, PQGenConn, RV, Query, Params, AdaptContext -from .proto import ConnectionType +from .proto import PQGen, PQGenConn, RV, RowFactory, Query, Params +from .proto import AdaptContext, ConnectionType + from .cursor import Cursor, AsyncCursor from .conninfo import make_conninfo from .generators import notifies from .transaction import Transaction, AsyncTransaction @@@ -448,32 -444,35 +444,50 @@@ class Connection(BaseConnection) """Close the database connection.""" self.pgconn.finish() + @overload - def cursor(self, *, binary: bool = False) -> Cursor: ++ def cursor( ++ self, *, binary: bool = False, row_factory: Optional[RowFactory] = None ++ ) -> Cursor: + ... + + @overload - def cursor(self, name: str, *, binary: bool = False) -> ServerCursor: ++ def cursor( ++ self, ++ name: str, ++ *, ++ binary: bool = False, ++ row_factory: Optional[RowFactory] = None, ++ ) -> ServerCursor: + ... + def cursor( - self, name: str = "", *, binary: bool = False + self, + name: str = "", ++ *, + binary: bool = False, + row_factory: Optional[RowFactory] = None, - ) -> "Cursor": + ) -> Union[Cursor, ServerCursor]: """ - Return a new `Cursor` to send commands and queries to the connection. + Return a new cursor to send commands and queries to the connection. """ - if name: - raise NotImplementedError - format = Format.BINARY if binary else Format.TEXT - return self.cursor_factory( - self, format=format, row_factory=row_factory - ) + if name: - return ServerCursor(self, name=name, format=format) ++ return ServerCursor( ++ self, name=name, format=format, row_factory=row_factory ++ ) + else: - return Cursor(self, format=format) ++ return Cursor(self, format=format, row_factory=row_factory) def execute( self, query: Query, params: Optional[Params] = None, + *, prepare: Optional[bool] = None, + row_factory: Optional[RowFactory] = None, - ) -> "Cursor": + ) -> Cursor: """Execute a query and return a cursor to read its results.""" - cur = self.cursor() + cur = self.cursor(row_factory=row_factory) return cur.execute(query, params, prepare=prepare) def commit(self) -> None: @@@ -591,31 -587,34 +602,49 @@@ class AsyncConnection(BaseConnection) async def close(self) -> None: self.pgconn.finish() - async def cursor( + @overload - def cursor(self, *, binary: bool = False) -> AsyncCursor: ++ def cursor( ++ self, *, binary: bool = False, row_factory: Optional[RowFactory] = None ++ ) -> AsyncCursor: + ... + + @overload - def cursor(self, name: str, *, binary: bool = False) -> AsyncServerCursor: ++ def cursor( ++ self, ++ name: str, ++ *, ++ binary: bool = False, ++ row_factory: Optional[RowFactory] = None, ++ ) -> AsyncServerCursor: + ... + + def cursor( - self, name: str = "", *, binary: bool = False + self, + name: str = "", ++ *, + binary: bool = False, + row_factory: Optional[RowFactory] = None, - ) -> "AsyncCursor": + ) -> Union[AsyncCursor, AsyncServerCursor]: """ Return a new `AsyncCursor` to send commands and queries to the connection. """ - if name: - raise NotImplementedError - format = Format.BINARY if binary else Format.TEXT - return self.cursor_factory( - self, format=format, row_factory=row_factory - ) + if name: - return AsyncServerCursor(self, name=name, format=format) ++ return AsyncServerCursor( ++ self, name=name, format=format, row_factory=row_factory ++ ) + else: - return AsyncCursor(self, format=format) ++ return AsyncCursor(self, format=format, row_factory=row_factory) async def execute( self, query: Query, params: Optional[Params] = None, + *, prepare: Optional[bool] = None, + row_factory: Optional[RowFactory] = None, - ) -> "AsyncCursor": - cur = await self.cursor(row_factory=row_factory) + ) -> AsyncCursor: - cur = self.cursor() ++ cur = self.cursor(row_factory=row_factory) return await cur.execute(query, params, prepare=prepare) async def commit(self) -> None: diff --cc psycopg3/psycopg3/cursor.py index 362fc750e,b341f03a6..c24a3f985 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@@ -59,10 -59,7 +60,11 @@@ class BaseCursor(Generic[ConnectionType _tx: "Transformer" def __init__( - self, connection: ConnectionType, *, format: Format = Format.TEXT + self, + connection: ConnectionType, ++ *, + format: Format = Format.TEXT, + row_factory: Optional[RowFactory] = None, ): self._conn = connection self.format = format @@@ -532,10 -554,10 +564,10 @@@ class Cursor(BaseCursor["Connection"]) self._check_result() assert self.pgresult records = self._tx.load_rows(self._pos, self.pgresult.ntuples) - self._pos += self.pgresult.ntuples + self._pos = self.pgresult.ntuples return records - def __iter__(self) -> Iterator[Sequence[Any]]: + def __iter__(self) -> Iterator[Row]: self._check_result() load = self._tx.load_row @@@ -629,10 -664,10 +674,10 @@@ class AsyncCursor(BaseCursor["AsyncConn self._check_result() assert self.pgresult records = self._tx.load_rows(self._pos, self.pgresult.ntuples) - self._pos += self.pgresult.ntuples + self._pos = self.pgresult.ntuples return records - async def __aiter__(self) -> AsyncIterator[Sequence[Any]]: + async def __aiter__(self) -> AsyncIterator[Row]: self._check_result() load = self._tx.load_row diff --cc psycopg3/psycopg3/server_cursor.py index 000000000,88f7bad0e..30624ac57 mode 000000,100644..100644 --- a/psycopg3/psycopg3/server_cursor.py +++ b/psycopg3/psycopg3/server_cursor.py @@@ -1,0 -1,394 +1,396 @@@ + """ + psycopg3 server-side cursor objects. + """ + + # Copyright (C) 2020-2021 The Psycopg Team + + import warnings + from types import TracebackType + from typing import Any, AsyncIterator, Generic, List, Iterator, Optional + from typing import Sequence, Type, Tuple, TYPE_CHECKING + + from . import pq + from . import sql + from . import errors as e + from .cursor import BaseCursor, execute -from .proto import ConnectionType, Query, Params, PQGen ++from .proto import ConnectionType, Query, Params, PQGen, RowFactory + + if TYPE_CHECKING: + from .connection import BaseConnection # noqa: F401 + from .connection import Connection, AsyncConnection # noqa: F401 + + DEFAULT_ITERSIZE = 100 + + + class ServerCursorHelper(Generic[ConnectionType]): + __slots__ = ("name", "described") + """Helper object for common ServerCursor code. + + TODO: this should be a mixin, but couldn't find a way to work it + correctly with the generic. + """ + + def __init__(self, name: str): + self.name = name + self.described = False + + def _repr(self, cur: BaseCursor[ConnectionType]) -> str: + cls = f"{cur.__class__.__module__}.{cur.__class__.__qualname__}" + info = pq.misc.connection_summary(cur._conn.pgconn) + if cur._closed: + status = "closed" + elif not cur._pgresult: + status = "no result" + else: + status = pq.ExecStatus(cur._pgresult.status).name + return f"<{cls} {self.name!r} [{status}] {info} at 0x{id(cur):x}>" + + def _declare_gen( + self, + cur: BaseCursor[ConnectionType], + query: Query, + params: Optional[Params] = None, + ) -> PQGen[None]: + """Generator implementing `ServerCursor.execute()`.""" + conn = cur._conn + + # If the cursor is being reused, the previous one must be closed. + if self.described: + yield from self._close_gen(cur) + self.described = False + + yield from cur._start_query(query) + pgq = cur._convert_query(query, params) + cur._execute_send(pgq) + results = yield from execute(conn.pgconn) + cur._execute_results(results) + + # The above result is an COMMAND_OK. Get the cursor result shape + yield from self._describe_gen(cur) + + def _describe_gen(self, cur: BaseCursor[ConnectionType]) -> PQGen[None]: + conn = cur._conn + conn.pgconn.send_describe_portal( + self.name.encode(conn.client_encoding) + ) + results = yield from execute(conn.pgconn) + cur._execute_results(results) + self.described = True + + def _close_gen(self, cur: BaseCursor[ConnectionType]) -> PQGen[None]: + # if the connection is not in a sane state, don't even try + if cur._conn.pgconn.transaction_status not in ( + pq.TransactionStatus.IDLE, + pq.TransactionStatus.INTRANS, + ): + return + + # if we didn't declare the cursor ourselves we still have to close it + # but we must make sure it exists. + if not self.described: + query = sql.SQL( + "select 1 from pg_catalog.pg_cursors where name = {}" + ).format(sql.Literal(self.name)) + res = yield from cur._conn._exec_command(query) + if res.ntuples == 0: + return + + query = sql.SQL("close {}").format(sql.Identifier(self.name)) + yield from cur._conn._exec_command(query) + + def _fetch_gen( + self, cur: BaseCursor[ConnectionType], num: Optional[int] + ) -> PQGen[List[Tuple[Any, ...]]]: + # If we are stealing the cursor, make sure we know its shape + if not self.described: + yield from cur._start_query() + yield from self._describe_gen(cur) + + if num is not None: + howmuch: sql.Composable = sql.Literal(num) + else: + howmuch = sql.SQL("all") + + query = sql.SQL("fetch forward {} from {}").format( + howmuch, sql.Identifier(self.name) + ) + res = yield from cur._conn._exec_command(query) + + # TODO: loaders don't need to be refreshed + cur.pgresult = res + return cur._tx.load_rows(0, res.ntuples) + + def _scroll_gen( + self, cur: BaseCursor[ConnectionType], value: int, mode: str + ) -> PQGen[None]: + if mode not in ("relative", "absolute"): + raise ValueError( + f"bad mode: {mode}. It should be 'relative' or 'absolute'" + ) + query = sql.SQL("move{} {} from {}").format( + sql.SQL(" absolute" if mode == "absolute" else ""), + sql.Literal(value), + sql.Identifier(self.name), + ) + yield from cur._conn._exec_command(query) + + def _make_declare_statement( + self, + cur: BaseCursor[ConnectionType], + query: Query, + scrollable: Optional[bool], + hold: bool, + ) -> sql.Composable: + + if isinstance(query, bytes): + query = query.decode(cur._conn.client_encoding) + if not isinstance(query, sql.Composable): + query = sql.SQL(query) + + parts = [ + sql.SQL("declare"), + sql.Identifier(self.name), + ] + if scrollable is not None: + parts.append(sql.SQL("scroll" if scrollable else "no scroll")) + parts.append(sql.SQL("cursor")) + if hold: + parts.append(sql.SQL("with hold")) + parts.append(sql.SQL("for")) + parts.append(query) + + return sql.SQL(" ").join(parts) + + + class ServerCursor(BaseCursor["Connection"]): + __module__ = "psycopg3" + __slots__ = ("_helper", "itersize") + + def __init__( + self, + connection: "Connection", + name: str, + *, + format: pq.Format = pq.Format.TEXT, ++ row_factory: Optional[RowFactory] = None, + ): - super().__init__(connection, format=format) ++ super().__init__(connection, format=format, row_factory=row_factory) + self._helper: ServerCursorHelper["Connection"] = ServerCursorHelper( + name + ) + self.itersize = DEFAULT_ITERSIZE + + def __del__(self) -> None: + if not self._closed: + warnings.warn( + f"the server-side cursor {self} was deleted while still open." + f" Please use 'with' or '.close()' to close the cursor properly", + ResourceWarning, + ) + + def __repr__(self) -> str: + return self._helper._repr(self) + + def __enter__(self) -> "ServerCursor": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.close() + + @property + def name(self) -> str: + """The name of the cursor.""" + return self._helper.name + + def close(self) -> None: + """ + Close the current cursor and free associated resources. + """ + with self._conn.lock: + self._conn.wait(self._helper._close_gen(self)) + self._close() + + def execute( + self, + query: Query, + params: Optional[Params] = None, + *, + scrollable: Optional[bool] = None, + hold: bool = False, + ) -> "ServerCursor": + """ + Open a cursor to execute a query to the database. + """ + query = self._helper._make_declare_statement( + self, query, scrollable=scrollable, hold=hold + ) + with self._conn.lock: + self._conn.wait(self._helper._declare_gen(self, query, params)) + return self + + def executemany(self, query: Query, params_seq: Sequence[Params]) -> None: + """Method not implemented for server-side cursors.""" + raise e.NotSupportedError( + "executemany not supported on server-side cursors" + ) + + def fetchone(self) -> Optional[Sequence[Any]]: + with self._conn.lock: + recs = self._conn.wait(self._helper._fetch_gen(self, 1)) + if recs: + self._pos += 1 + return recs[0] + else: + return None + + def fetchmany(self, size: int = 0) -> Sequence[Sequence[Any]]: + if not size: + size = self.arraysize + with self._conn.lock: + recs = self._conn.wait(self._helper._fetch_gen(self, size)) + self._pos += len(recs) + return recs + + def fetchall(self) -> Sequence[Sequence[Any]]: + with self._conn.lock: + recs = self._conn.wait(self._helper._fetch_gen(self, None)) + self._pos += len(recs) + return recs + + def __iter__(self) -> Iterator[Sequence[Any]]: + while True: + with self._conn.lock: + recs = self._conn.wait( + self._helper._fetch_gen(self, self.itersize) + ) + for rec in recs: + self._pos += 1 + yield rec + if len(recs) < self.itersize: + break + + def scroll(self, value: int, mode: str = "relative") -> None: + with self._conn.lock: + self._conn.wait(self._helper._scroll_gen(self, value, mode)) + # Postgres doesn't have a reliable way to report a cursor out of bound + if mode == "relative": + self._pos += value + else: + self._pos = value + + + class AsyncServerCursor(BaseCursor["AsyncConnection"]): + __module__ = "psycopg3" + __slots__ = ("_helper", "itersize") + + def __init__( + self, + connection: "AsyncConnection", + name: str, + *, + format: pq.Format = pq.Format.TEXT, ++ row_factory: Optional[RowFactory] = None, + ): - super().__init__(connection, format=format) ++ super().__init__(connection, format=format, row_factory=row_factory) + self._helper: ServerCursorHelper["AsyncConnection"] + self._helper = ServerCursorHelper(name) + self.itersize = DEFAULT_ITERSIZE + + def __del__(self) -> None: + if not self._closed: + warnings.warn( + f"the server-side cursor {self} was deleted while still open." + f" Please use 'with' or '.close()' to close the cursor properly", + ResourceWarning, + ) + + def __repr__(self) -> str: + return self._helper._repr(self) + + async def __aenter__(self) -> "AsyncServerCursor": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.close() + + @property + def name(self) -> str: + return self._helper.name + + async def close(self) -> None: + async with self._conn.lock: + await self._conn.wait(self._helper._close_gen(self)) + self._close() + + async def execute( + self, + query: Query, + params: Optional[Params] = None, + *, + scrollable: Optional[bool] = None, + hold: bool = False, + ) -> "AsyncServerCursor": + query = self._helper._make_declare_statement( + self, query, scrollable=scrollable, hold=hold + ) + async with self._conn.lock: + await self._conn.wait( + self._helper._declare_gen(self, query, params) + ) + return self + + async def executemany( + self, query: Query, params_seq: Sequence[Params] + ) -> None: + raise e.NotSupportedError( + "executemany not supported on server-side cursors" + ) + + async def fetchone(self) -> Optional[Sequence[Any]]: + async with self._conn.lock: + recs = await self._conn.wait(self._helper._fetch_gen(self, 1)) + if recs: + self._pos += 1 + return recs[0] + else: + return None + + async def fetchmany(self, size: int = 0) -> Sequence[Sequence[Any]]: + if not size: + size = self.arraysize + async with self._conn.lock: + recs = await self._conn.wait(self._helper._fetch_gen(self, size)) + self._pos += len(recs) + return recs + + async def fetchall(self) -> Sequence[Sequence[Any]]: + async with self._conn.lock: + recs = await self._conn.wait(self._helper._fetch_gen(self, None)) + self._pos += len(recs) + return recs + + async def __aiter__(self) -> AsyncIterator[Sequence[Any]]: + while True: + async with self._conn.lock: + recs = await self._conn.wait( + self._helper._fetch_gen(self, self.itersize) + ) + for rec in recs: + self._pos += 1 + yield rec + if len(recs) < self.itersize: + break + + async def scroll(self, value: int, mode: str = "relative") -> None: + async with self._conn.lock: + await self._conn.wait(self._helper._scroll_gen(self, value, mode)) diff --cc tests/test_cursor.py index 9db2ab358,d22888e1f..7ff25f0a3 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@@ -263,22 -286,48 +286,64 @@@ def test_iter_stop(conn) assert list(cur) == [] +def test_row_factory(conn): + def my_row_factory(cur): + return lambda values: [-v for v in values] + + cur = conn.cursor(row_factory=my_row_factory) + cur.execute("select generate_series(1, 3)") + r = cur.fetchall() + assert r == [[-1], [-2], [-3]] + + cur.execute("select 42; select generate_series(1,3)") + assert cur.fetchall() == [[-42]] + assert cur.nextset() + assert cur.fetchall() == [[-1], [-2], [-3]] + assert cur.nextset() is None + + + def test_scroll(conn): + cur = conn.cursor() + with pytest.raises(psycopg3.ProgrammingError): + cur.scroll(0) + + cur.execute("select generate_series(0,9)") + cur.scroll(2) + assert cur.fetchone() == (2,) + cur.scroll(2) + assert cur.fetchone() == (5,) + cur.scroll(2, mode="relative") + assert cur.fetchone() == (8,) + cur.scroll(-1) + assert cur.fetchone() == (8,) + cur.scroll(-2) + assert cur.fetchone() == (7,) + cur.scroll(2, mode="absolute") + assert cur.fetchone() == (2,) + + # on the boundary + cur.scroll(0, mode="absolute") + assert cur.fetchone() == (0,) + with pytest.raises(IndexError): + cur.scroll(-1, mode="absolute") + + cur.scroll(0, mode="absolute") + with pytest.raises(IndexError): + cur.scroll(-1) + + cur.scroll(9, mode="absolute") + assert cur.fetchone() == (9,) + with pytest.raises(IndexError): + cur.scroll(10, mode="absolute") + + cur.scroll(9, mode="absolute") + with pytest.raises(IndexError): + cur.scroll(1) + + with pytest.raises(ValueError): + cur.scroll(1, "wat") + + def test_query_params_execute(conn): cur = conn.cursor() assert cur.query is None diff --cc tests/test_cursor_async.py index fba7fb502,4aeb39666..e56777323 --- a/tests/test_cursor_async.py +++ b/tests/test_cursor_async.py @@@ -268,32 -291,50 +291,74 @@@ async def test_iter_stop(aconn) assert False +async def test_row_factory(aconn): + def my_row_factory(cursor): + def mkrow(values): + assert cursor.description is not None + titles = [c.name for c in cursor.description] + return [ + f"{value.upper()}{title}" + for title, value in zip(titles, values) + ] + + return mkrow + - cur = await aconn.cursor(row_factory=my_row_factory) ++ cur = aconn.cursor(row_factory=my_row_factory) + await cur.execute("select 'foo' as bar") + (r,) = await cur.fetchone() + assert r == "FOObar" + + await cur.execute("select 'x' as x; select 'y' as y, 'z' as z") + assert await cur.fetchall() == [["Xx"]] + assert cur.nextset() + assert await cur.fetchall() == [["Yy", "Zz"]] + assert cur.nextset() is None + + + async def test_scroll(aconn): + cur = aconn.cursor() + with pytest.raises(psycopg3.ProgrammingError): + await cur.scroll(0) + + await cur.execute("select generate_series(0,9)") + await cur.scroll(2) + assert await cur.fetchone() == (2,) + await cur.scroll(2) + assert await cur.fetchone() == (5,) + await cur.scroll(2, mode="relative") + assert await cur.fetchone() == (8,) + await cur.scroll(-1) + assert await cur.fetchone() == (8,) + await cur.scroll(-2) + assert await cur.fetchone() == (7,) + await cur.scroll(2, mode="absolute") + assert await cur.fetchone() == (2,) + + # on the boundary + await cur.scroll(0, mode="absolute") + assert await cur.fetchone() == (0,) + with pytest.raises(IndexError): + await cur.scroll(-1, mode="absolute") + + await cur.scroll(0, mode="absolute") + with pytest.raises(IndexError): + await cur.scroll(-1) + + await cur.scroll(9, mode="absolute") + assert await cur.fetchone() == (9,) + with pytest.raises(IndexError): + await cur.scroll(10, mode="absolute") + + await cur.scroll(9, mode="absolute") + with pytest.raises(IndexError): + await cur.scroll(1) + + with pytest.raises(ValueError): + await cur.scroll(1, "wat") + + async def test_query_params_execute(aconn): - cur = await aconn.cursor() + cur = aconn.cursor() assert cur.query is None assert cur.params is None