From: Daniele Varrazzo Date: Thu, 12 May 2022 02:25:51 +0000 (+0200) Subject: feat: add ClientCursor and AsyncClientCursor classes X-Git-Tag: 3.1~99^2~8 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=462b06665b8d6e810dbcd7c6187f998bf93899e4;p=thirdparty%2Fpsycopg.git feat: add ClientCursor and AsyncClientCursor classes --- diff --git a/psycopg/psycopg/__init__.py b/psycopg/psycopg/__init__.py index c86d26822..baadf30c3 100644 --- a/psycopg/psycopg/__init__.py +++ b/psycopg/psycopg/__init__.py @@ -23,6 +23,7 @@ from .connection import BaseConnection, Connection, Notify from .transaction import Rollback, Transaction, AsyncTransaction from .cursor_async import AsyncCursor from .server_cursor import AsyncServerCursor, ServerCursor +from .client_cursor import AsyncClientCursor, ClientCursor from .connection_async import AsyncConnection from . import dbapi20 @@ -57,6 +58,7 @@ types.array.register_all_arrays(adapters) # this is the canonical place to obtain them and should be used by MyPy too, # so that function signatures are consistent with the documentation. __all__ = [ + "AsyncClientCursor", "AsyncConnection", "AsyncCopy", "AsyncCursor", @@ -64,6 +66,7 @@ __all__ = [ "AsyncServerCursor", "AsyncTransaction", "BaseConnection", + "ClientCursor", "Column", "Connection", "ConnectionInfo", diff --git a/psycopg/psycopg/_queries.py b/psycopg/psycopg/_queries.py index 250454d90..2b0838e34 100644 --- a/psycopg/psycopg/_queries.py +++ b/psycopg/psycopg/_queries.py @@ -97,6 +97,53 @@ class PostgresQuery: self.formats = None +class PostgresClientQuery(PostgresQuery): + """ + PostgresQuery subclass merging query and arguments client-side. + """ + + __slots__ = ("template",) + + def convert(self, query: Query, vars: Optional[Params]) -> None: + """ + Set up the query and parameters to convert. + + The results of this function can be obtained accessing the object + attributes (`query`, `params`, `types`, `formats`). + """ + if isinstance(query, str): + bquery = query.encode(self._encoding) + elif isinstance(query, Composable): + bquery = query.as_bytes(self._tx) + else: + bquery = query + + if vars is not None: + (self.template, self._order, self._parts) = _query2pg_client( + bquery, self._encoding + ) + else: + self.query = bquery + self._order = None + + self.dump(vars) + + def dump(self, vars: Optional[Params]) -> None: + """ + Process a new set of variables on the query processed by `convert()`. + + This method updates `params` and `types`. + """ + if vars is not None: + params = _validate_and_reorder_params(self._parts, vars, self._order) + self.params = tuple( + self._tx.as_literal(p) if p is not None else b"NULL" for p in params + ) + self.query = self.template % self.params + else: + self.params = None + + @lru_cache() def _query2pg( query: bytes, encoding: str @@ -106,7 +153,7 @@ def _query2pg( - Convert Python placeholders (``%s``, ``%(name)s``) into Postgres format (``$1``, ``$2``) - - placeholders can be %s or %b (text or binary) + - placeholders can be %s, %t, or %b (auto, text or binary) - return ``query`` (bytes), ``formats`` (list of formats) ``order`` (sequence of names used in the query, in the position they appear) ``parts`` (splits of queries and placeholders). @@ -148,6 +195,43 @@ def _query2pg( return b"".join(chunks), formats, order, parts +@lru_cache() +def _query2pg_client( + query: bytes, encoding: str +) -> Tuple[bytes, Optional[List[str]], List[QueryPart]]: + """ + Convert Python query and params into a template to perform client-side binding + """ + parts = _split_query(query, encoding) + order: Optional[List[str]] = None + chunks: List[bytes] = [] + + if isinstance(parts[0].item, int): + for part in parts[:-1]: + assert isinstance(part.item, int) + chunks.append(part.pre) + chunks.append(b"%s") + + elif isinstance(parts[0].item, str): + seen: Dict[str, Tuple[bytes, PyFormat]] = {} + order = [] + for part in parts[:-1]: + assert isinstance(part.item, str) + chunks.append(part.pre) + if part.item not in seen: + ph = b"%s" + seen[part.item] = (ph, part.format) + order.append(part.item) + chunks.append(ph) + else: + chunks.append(seen[part.item][0]) + + # last part + chunks.append(parts[-1].pre) + + return b"".join(chunks), order, parts + + def _validate_and_reorder_params( parts: List[QueryPart], vars: Params, order: Optional[List[str]] ) -> Sequence[Any]: diff --git a/psycopg/psycopg/client_cursor.py b/psycopg/psycopg/client_cursor.py new file mode 100644 index 000000000..177fb7840 --- /dev/null +++ b/psycopg/psycopg/client_cursor.py @@ -0,0 +1,81 @@ +""" +psycopg client-side binding cursors +""" + +# Copyright (C) 2022 The Psycopg Team + +from typing import Optional, Tuple, TYPE_CHECKING +from functools import partial + +from ._queries import PostgresQuery, PostgresClientQuery + +from . import errors as e +from .pq import Format +from .abc import ConnectionType, Query, Params +from .rows import Row +from .cursor import BaseCursor, Cursor +from ._preparing import Prepare +from .cursor_async import AsyncCursor + +if TYPE_CHECKING: + from .connection import Connection # noqa: F401 + from .connection_async import AsyncConnection # noqa: F401 + + +class ClientCursorMixin(BaseCursor[ConnectionType, Row]): + def _execute_send( + self, + query: PostgresQuery, + *, + no_pqexec: bool = False, + binary: Optional[bool] = None, + ) -> None: + if binary is None: + fmt = self.format + else: + fmt = Format.BINARY if binary else Format.TEXT + + if fmt == Format.BINARY: + raise e.NotSupportedError( + "client-side cursors don't support binary results" + ) + + if no_pqexec: + raise e.NotSupportedError( + "PQexec operations not supported by client-side cursors" + ) + + self._query = query + # if we don't have to, let's use exec_ as it can run more than + # one query in one go + if self._conn._pipeline: + self._conn._pipeline.command_queue.append( + partial(self._pgconn.send_query, query.query) + ) + else: + self._pgconn.send_query(query.query) + + def _convert_query( + self, query: Query, params: Optional[Params] = None + ) -> PostgresQuery: + pgq = PostgresClientQuery(self._tx) + pgq.convert(query, params) + return pgq + + def _get_prepared( + self, pgq: PostgresQuery, prepare: Optional[bool] = None + ) -> Tuple[Prepare, bytes]: + return (Prepare.NO, b"") + + def _is_pipeline_supported(self) -> bool: + return False + + +class ClientCursor(ClientCursorMixin["Connection[Row]", Row], Cursor[Row]): + pass + + +class AsyncClientCursor( + ClientCursorMixin["AsyncConnection[Row]", Row], AsyncCursor[Row] +): + pass diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py index 22a10e9e4..12db07a3f 100644 --- a/psycopg/psycopg/cursor.py +++ b/psycopg/psycopg/cursor.py @@ -7,14 +7,13 @@ psycopg cursor objects from functools import partial from types import TracebackType from typing import Any, Generic, Iterable, Iterator, List -from typing import Optional, NoReturn, Sequence, Type, TypeVar +from typing import Optional, NoReturn, Sequence, Tuple, Type, TypeVar from typing import overload, TYPE_CHECKING from contextlib import contextmanager from . import pq from . import adapt from . import errors as e - from .pq import ExecStatus, Format from .abc import ConnectionType, Query, Params, PQGen from .copy import Copy @@ -293,7 +292,7 @@ class BaseCursor(Generic[ConnectionType, Row]): binary: Optional[bool] = None, ) -> PQGen[Optional[List["PGresult"]]]: # Check if the query is prepared or needs preparing - prep, name = self._conn._prepared.get(pgq, prepare) + prep, name = self._get_prepared(pgq, prepare) if prep is Prepare.NO: # The query must be executed without preparing self._execute_send(pgq, binary=binary) @@ -328,6 +327,11 @@ class BaseCursor(Generic[ConnectionType, Row]): return results + def _get_prepared( + self, pgq: PostgresQuery, prepare: Optional[bool] = None + ) -> Tuple[Prepare, bytes]: + return self._conn._prepared.get(pgq, prepare) + def _stream_send_gen( self, query: Query, @@ -722,7 +726,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]): Execute the same command with a sequence of input data. """ try: - if Pipeline.is_supported(): + if self._is_pipeline_supported(): # If there is already a pipeline, ride it, in order to avoid # sending unnecessary Sync. with self._conn.lock: @@ -745,6 +749,9 @@ class Cursor(BaseCursor["Connection[Any]", Row]): except e.Error as ex: raise ex.with_traceback(None) + def _is_pipeline_supported(self) -> bool: + return Pipeline.is_supported() + def stream( self, query: Query, diff --git a/tests/test_client_cursor.py b/tests/test_client_cursor.py new file mode 100644 index 000000000..d5b6193f1 --- /dev/null +++ b/tests/test_client_cursor.py @@ -0,0 +1,769 @@ +import gc +import pickle +import weakref +from typing import List + +import pytest + +import psycopg +from psycopg import sql, rows +from psycopg.adapt import PyFormat +from psycopg.postgres import types as builtins + +from .utils import gc_collect +from .test_cursor import my_row_factory + + +@pytest.fixture +def conn(conn): + conn.cursor_factory = psycopg.ClientCursor + return conn + + +def test_init(conn): + cur = psycopg.ClientCursor(conn) + cur.execute("select 1") + assert cur.fetchone() == (1,) + + conn.row_factory = rows.dict_row + cur = psycopg.ClientCursor(conn) + cur.execute("select 1 as a") + assert cur.fetchone() == {"a": 1} + + +def test_init_factory(conn): + cur = psycopg.ClientCursor(conn, row_factory=rows.dict_row) + cur.execute("select 1 as a") + assert cur.fetchone() == {"a": 1} + + +def test_from_cursor_factory(dsn): + with psycopg.connect(dsn, cursor_factory=psycopg.ClientCursor) as aconn: + cur = aconn.cursor() + assert type(cur) is psycopg.ClientCursor + + cur.execute("select %s", (1,)) + assert cur.fetchone() == (1,) + assert cur._query + assert cur._query.query == b"select 1" + + +def test_close(conn): + cur = conn.cursor() + assert not cur.closed + cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + cur.execute("select 'foo'") + + cur.close() + assert cur.closed + + +def test_cursor_close_fetchone(conn): + cur = conn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + for _ in range(5): + cur.fetchone() + + cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + cur.fetchone() + + +def test_cursor_close_fetchmany(conn): + cur = conn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + assert len(cur.fetchmany(2)) == 2 + + cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + cur.fetchmany(2) + + +def test_cursor_close_fetchall(conn): + cur = conn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + assert len(cur.fetchall()) == 10 + + cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + cur.fetchall() + + +def test_context(conn): + with conn.cursor() as cur: + assert not cur.closed + + assert cur.closed + + +@pytest.mark.slow +def test_weakref(conn): + cur = conn.cursor() + w = weakref.ref(cur) + cur.close() + del cur + gc_collect() + assert w() is None + + +def test_pgresult(conn): + cur = conn.cursor() + cur.execute("select 1") + assert cur.pgresult + cur.close() + assert not cur.pgresult + + +def test_statusmessage(conn): + cur = conn.cursor() + assert cur.statusmessage is None + + cur.execute("select generate_series(1, 10)") + assert cur.statusmessage == "SELECT 10" + + cur.execute("create table statusmessage ()") + assert cur.statusmessage == "CREATE TABLE" + + with pytest.raises(psycopg.ProgrammingError): + cur.execute("wat") + assert cur.statusmessage is None + + +def test_execute_sql(conn): + cur = conn.cursor() + cur.execute(sql.SQL("select {value}").format(value="hello")) + assert cur.fetchone() == ("hello",) + + +def test_execute_many_results(conn): + cur = conn.cursor() + assert cur.nextset() is None + + rv = cur.execute("select 'foo'; select generate_series(1,3)") + assert rv is cur + assert cur.fetchall() == [("foo",)] + assert cur.rowcount == 1 + assert cur.nextset() + assert cur.fetchall() == [(1,), (2,), (3,)] + assert cur.nextset() is None + + cur.close() + assert cur.nextset() is None + + +def test_execute_sequence(conn): + cur = conn.cursor() + rv = cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) + assert rv is cur + assert len(cur._results) == 1 + assert cur.pgresult.get_value(0, 0) == b"1" + assert cur.pgresult.get_value(0, 1) == b"foo" + assert cur.pgresult.get_value(0, 2) is None + assert cur.nextset() is None + + +@pytest.mark.parametrize("query", ["", " ", ";"]) +def test_execute_empty_query(conn, query): + cur = conn.cursor() + cur.execute(query) + assert cur.pgresult.status == cur.ExecStatus.EMPTY_QUERY + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + + +def test_execute_type_change(conn): + # issue #112 + conn.execute("create table bug_112 (num integer)") + sql = "insert into bug_112 (num) values (%s)" + cur = conn.cursor() + cur.execute(sql, (1,)) + cur.execute(sql, (100_000,)) + cur.execute("select num from bug_112 order by num") + assert cur.fetchall() == [(1,), (100_000,)] + + +def test_executemany_type_change(conn): + conn.execute("create table bug_112 (num integer)") + sql = "insert into bug_112 (num) values (%s)" + cur = conn.cursor() + cur.executemany(sql, [(1,), (100_000,)]) + cur.execute("select num from bug_112 order by num") + assert cur.fetchall() == [(1,), (100_000,)] + + +@pytest.mark.parametrize( + "query", ["copy testcopy from stdin", "copy testcopy to stdout"] +) +def test_execute_copy(conn, query): + cur = conn.cursor() + cur.execute("create table testcopy (id int)") + with pytest.raises(psycopg.ProgrammingError): + cur.execute(query) + + +def test_fetchone(conn): + cur = conn.cursor() + cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) + assert cur.pgresult.fformat(0) == 0 + + row = cur.fetchone() + assert row == (1, "foo", None) + row = cur.fetchone() + assert row is None + + +def test_binary_cursor_execute(conn): + with pytest.raises(psycopg.NotSupportedError): + cur = conn.cursor(binary=True) + cur.execute("select %s, %s", [1, None]) + + +def test_execute_binary(conn): + with pytest.raises(psycopg.NotSupportedError): + cur = conn.cursor() + cur.execute("select %s, %s", [1, None], binary=True) + + +def test_binary_cursor_text_override(conn): + cur = conn.cursor(binary=True) + cur.execute("select %s, %s", [1, None], binary=False) + assert cur.fetchone() == (1, None) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"1" + + +@pytest.mark.parametrize("encoding", ["utf8", "latin9"]) +def test_query_encode(conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + cur = conn.cursor() + (res,) = cur.execute("select '\u20ac'").fetchone() + assert res == "\u20ac" + + +def test_query_badenc(conn): + conn.execute("set client_encoding to latin1") + cur = conn.cursor() + with pytest.raises(UnicodeEncodeError): + cur.execute("select '\u20ac'") + + +@pytest.fixture(scope="session") +def _execmany(svcconn): + cur = svcconn.cursor() + cur.execute( + """ + drop table if exists execmany; + create table execmany (id serial primary key, num integer, data text) + """ + ) + + +@pytest.fixture(scope="function") +def execmany(svcconn, _execmany): + cur = svcconn.cursor() + cur.execute("truncate table execmany") + + +def test_executemany(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + cur.execute("select num, data from execmany order by 1") + assert cur.fetchall() == [(10, "hello"), (20, "world")] + + +def test_executemany_name(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%(num)s, %(data)s)", + [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}], + ) + cur.execute("select num, data from execmany order by 1") + assert cur.fetchall() == [(11, "hello"), (21, "world")] + + +def test_executemany_no_data(conn, execmany): + cur = conn.cursor() + cur.executemany("insert into execmany(num, data) values (%s, %s)", []) + assert cur.rowcount == 0 + + +def test_executemany_rowcount(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + + +def test_executemany_returning(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s) returning num", + [(10, "hello"), (20, "world")], + returning=True, + ) + assert cur.rowcount == 2 + assert cur.fetchone() == (10,) + assert cur.nextset() + assert cur.fetchone() == (20,) + assert cur.nextset() is None + + +def test_executemany_returning_discard(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s) returning num", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + assert cur.nextset() is None + + +def test_executemany_no_result(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + returning=True, + ) + assert cur.rowcount == 2 + assert cur.statusmessage.startswith("INSERT") + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + pgresult = cur.pgresult + assert cur.nextset() + assert cur.statusmessage.startswith("INSERT") + assert pgresult is not cur.pgresult + assert cur.nextset() is None + + +def test_executemany_rowcount_no_hit(conn, execmany): + cur = conn.cursor() + cur.executemany("delete from execmany where id = %s", [(-1,), (-2,)]) + assert cur.rowcount == 0 + cur.executemany("delete from execmany where id = %s", []) + assert cur.rowcount == 0 + cur.executemany("delete from execmany where id = %s returning num", [(-1,), (-2,)]) + assert cur.rowcount == 0 + + +@pytest.mark.parametrize( + "query", + [ + "insert into nosuchtable values (%s, %s)", + "copy (select %s, %s) to stdout", + "wat (%s, %s)", + ], +) +def test_executemany_badquery(conn, query): + cur = conn.cursor() + with pytest.raises(psycopg.DatabaseError): + cur.executemany(query, [(10, "hello"), (20, "world")]) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_executemany_null_first(conn, fmt_in): + cur = conn.cursor() + cur.execute("create table testmany (a bigint, b bigint)") + cur.executemany( + f"insert into testmany values (%{fmt_in}, %{fmt_in})", + [[1, None], [3, 4]], + ) + with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)): + cur.executemany( + f"insert into testmany values (%{fmt_in}, %{fmt_in})", + [[1, ""], [3, 4]], + ) + + +def test_rowcount(conn): + cur = conn.cursor() + + cur.execute("select 1 from generate_series(1, 0)") + assert cur.rowcount == 0 + + cur.execute("select 1 from generate_series(1, 42)") + assert cur.rowcount == 42 + + cur.execute("create table test_rowcount_notuples (id int primary key)") + assert cur.rowcount == -1 + + cur.execute("insert into test_rowcount_notuples select generate_series(1, 42)") + assert cur.rowcount == 42 + + +def test_rownumber(conn): + cur = conn.cursor() + assert cur.rownumber is None + + cur.execute("select 1 from generate_series(1, 42)") + assert cur.rownumber == 0 + + cur.fetchone() + assert cur.rownumber == 1 + cur.fetchone() + assert cur.rownumber == 2 + cur.fetchmany(10) + assert cur.rownumber == 12 + rns: List[int] = [] + for i in cur: + assert cur.rownumber + rns.append(cur.rownumber) + if len(rns) >= 3: + break + assert rns == [13, 14, 15] + assert len(cur.fetchall()) == 42 - rns[-1] + assert cur.rownumber == 42 + + +def test_iter(conn): + cur = conn.cursor() + cur.execute("select generate_series(1, 3)") + assert list(cur) == [(1,), (2,), (3,)] + + +def test_iter_stop(conn): + cur = conn.cursor() + cur.execute("select generate_series(1, 3)") + for rec in cur: + assert rec == (1,) + break + + for rec in cur: + assert rec == (2,) + break + + assert cur.fetchone() == (3,) + assert list(cur) == [] + + +def test_row_factory(conn): + cur = conn.cursor(row_factory=my_row_factory) + + cur.execute("reset search_path") + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + + cur.execute("select 'foo' as bar") + (r,) = cur.fetchone() + assert r == "FOObar" + + cur.execute("select 'x' as x; select 'y' as y, 'z' as z") + assert cur.fetchall() == [["Xx"]] + assert cur.nextset() + assert cur.fetchall() == [["Yy", "Zz"]] + + cur.scroll(-1) + cur.row_factory = rows.dict_row + assert cur.fetchone() == {"y": "y", "z": "z"} + + +def test_row_factory_none(conn): + cur = conn.cursor(row_factory=None) + assert cur.row_factory is rows.tuple_row + r = cur.execute("select 1 as a, 2 as b").fetchone() + assert type(r) is tuple + assert r == (1, 2) + + +def test_bad_row_factory(conn): + def broken_factory(cur): + 1 / 0 + + cur = conn.cursor(row_factory=broken_factory) + with pytest.raises(ZeroDivisionError): + cur.execute("select 1") + + def broken_maker(cur): + def make_row(seq): + 1 / 0 + + return make_row + + cur = conn.cursor(row_factory=broken_maker) + cur.execute("select 1") + with pytest.raises(ZeroDivisionError): + cur.fetchone() + + +def test_scroll(conn): + cur = conn.cursor() + with pytest.raises(psycopg.ProgrammingError): + cur.scroll(0) + + cur.execute("select generate_series(0,9)") + cur.scroll(2) + assert cur.fetchone() == (2,) + cur.scroll(2) + assert cur.fetchone() == (5,) + cur.scroll(2, mode="relative") + assert cur.fetchone() == (8,) + cur.scroll(-1) + assert cur.fetchone() == (8,) + cur.scroll(-2) + assert cur.fetchone() == (7,) + cur.scroll(2, mode="absolute") + assert cur.fetchone() == (2,) + + # on the boundary + cur.scroll(0, mode="absolute") + assert cur.fetchone() == (0,) + with pytest.raises(IndexError): + cur.scroll(-1, mode="absolute") + + cur.scroll(0, mode="absolute") + with pytest.raises(IndexError): + cur.scroll(-1) + + cur.scroll(9, mode="absolute") + assert cur.fetchone() == (9,) + with pytest.raises(IndexError): + cur.scroll(10, mode="absolute") + + cur.scroll(9, mode="absolute") + with pytest.raises(IndexError): + cur.scroll(1) + + with pytest.raises(ValueError): + cur.scroll(1, "wat") + + +def test_query_params_execute(conn): + cur = conn.cursor() + assert cur._query is None + + cur.execute("select %t, %s::text", [1, None]) + assert cur._query is not None + assert cur._query.query == b"select 1, NULL::text" + assert cur._query.params == (b"1", b"NULL") + + cur.execute("select 1") + assert cur._query.query == b"select 1" + assert not cur._query.params + + with pytest.raises(psycopg.DataError): + cur.execute("select %t::int", ["wat"]) + + assert cur._query.query == b"select 'wat'::int" + assert cur._query.params == (b"'wat'",) + + +def test_query_params_executemany(conn): + cur = conn.cursor() + + cur.executemany("select %t, %t", [[1, 2], [3, 4]]) + assert cur._query.query == b"select 3, 4" + assert cur._query.params == (b"3", b"4") + + +def test_stream(conn): + cur = conn.cursor() + with pytest.raises(psycopg.NotSupportedError): + for rec in cur.stream( + "select i, '2021-01-01'::date + i from generate_series(1, %s) as i", + [2], + ): + pass + + +class TestColumn: + def test_description_attribs(self, conn): + curs = conn.cursor() + curs.execute( + """select + 3.14::decimal(10,2) as pi, + 'hello'::text as hi, + '2010-02-18'::date as now + """ + ) + assert len(curs.description) == 3 + for c in curs.description: + len(c) == 7 # DBAPI happy + for i, a in enumerate( + """ + name type_code display_size internal_size precision scale null_ok + """.split() + ): + assert c[i] == getattr(c, a) + + # Won't fill them up + assert c.null_ok is None + + c = curs.description[0] + assert c.name == "pi" + assert c.type_code == builtins["numeric"].oid + assert c.display_size is None + assert c.internal_size is None + assert c.precision == 10 + assert c.scale == 2 + + c = curs.description[1] + assert c.name == "hi" + assert c.type_code == builtins["text"].oid + assert c.display_size is None + assert c.internal_size is None + assert c.precision is None + assert c.scale is None + + c = curs.description[2] + assert c.name == "now" + assert c.type_code == builtins["date"].oid + assert c.display_size is None + assert c.internal_size == 4 + assert c.precision is None + assert c.scale is None + + def test_description_slice(self, conn): + curs = conn.cursor() + curs.execute("select 1::int as a") + curs.description[0][0:2] == ("a", 23) + + @pytest.mark.parametrize( + "type, precision, scale, dsize, isize", + [ + ("text", None, None, None, None), + ("varchar", None, None, None, None), + ("varchar(42)", None, None, 42, None), + ("int4", None, None, None, 4), + ("numeric", None, None, None, None), + ("numeric(10)", 10, 0, None, None), + ("numeric(10, 3)", 10, 3, None, None), + ("time", None, None, None, 8), + ("time(4)", 4, None, None, 8), + ("time(10)", 6, None, None, 8), + ], + ) + def test_details(self, conn, type, precision, scale, dsize, isize): + cur = conn.cursor() + cur.execute(f"select null::{type}") + col = cur.description[0] + repr(col) + assert col.precision == precision + assert col.scale == scale + assert col.display_size == dsize + assert col.internal_size == isize + + def test_pickle(self, conn): + curs = conn.cursor() + curs.execute( + """select + 3.14::decimal(10,2) as pi, + 'hello'::text as hi, + '2010-02-18'::date as now + """ + ) + description = curs.description + pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL) + unpickled = pickle.loads(pickled) + assert [tuple(d) for d in description] == [tuple(d) for d in unpickled] + + def test_no_col_query(self, conn): + cur = conn.execute("select") + assert cur.description == [] + assert cur.fetchall() == [()] + + def test_description_closed_connection(self, conn): + # If we have reasons to break this test we will (e.g. we really need + # the connection). In #172 it fails just by accident. + cur = conn.execute("select 1::int4 as foo") + conn.close() + assert len(cur.description) == 1 + col = cur.description[0] + assert col.name == "foo" + assert col.type_code == 23 + + def test_name_not_a_name(self, conn): + cur = conn.cursor() + (res,) = cur.execute("""select 'x' as "foo-bar" """).fetchone() + assert res == "x" + assert cur.description[0].name == "foo-bar" + + @pytest.mark.parametrize("encoding", ["utf8", "latin9"]) + def test_name_encode(self, conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + cur = conn.cursor() + (res,) = cur.execute("""select 'x' as "\u20ac" """).fetchone() + assert res == "x" + assert cur.description[0].name == "\u20ac" + + +def test_str(conn): + cur = conn.cursor() + assert "[IDLE]" in str(cur) + assert "[closed]" not in str(cur) + assert "[no result]" in str(cur) + cur.execute("select 1") + assert "[INTRANS]" in str(cur) + assert "[TUPLES_OK]" in str(cur) + assert "[closed]" not in str(cur) + assert "[no result]" not in str(cur) + cur.close() + assert "[closed]" in str(cur) + assert "[INTRANS]" in str(cur) + + +@pytest.mark.slow +@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"]) +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +def test_leak(dsn, faker, fetch, row_factory): + faker.choose_schema(ncols=5) + faker.make_records(10) + row_factory = getattr(rows, row_factory) + + def work(): + with psycopg.connect(dsn) as conn: + with psycopg.ClientCursor(conn, row_factory=row_factory) as cur: + cur.execute(faker.drop_stmt) + cur.execute(faker.create_stmt) + with faker.find_insert_problem(conn): + cur.executemany(faker.insert_stmt, faker.records) + + cur.execute(faker.select_stmt) + + if fetch == "one": + while 1: + tmp = cur.fetchone() + if tmp is None: + break + elif fetch == "many": + while 1: + tmp = cur.fetchmany(3) + if not tmp: + break + elif fetch == "all": + cur.fetchall() + elif fetch == "iter": + for rec in cur: + pass + + n = [] + gc_collect() + for i in range(3): + work() + gc_collect() + n.append(len(gc.get_objects())) + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" diff --git a/tests/test_client_cursor_async.py b/tests/test_client_cursor_async.py new file mode 100644 index 000000000..aa0c8f8ef --- /dev/null +++ b/tests/test_client_cursor_async.py @@ -0,0 +1,642 @@ +import gc +import pytest +import weakref +from typing import List + +import psycopg +from psycopg import sql, rows +from psycopg.adapt import PyFormat + +from .utils import gc_collect +from .test_cursor import my_row_factory +from .test_cursor import execmany, _execmany # noqa: F401 + +execmany = execmany # avoid F811 underneath +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +async def aconn(aconn): + aconn.cursor_factory = psycopg.AsyncClientCursor + return aconn + + +async def test_init(aconn): + cur = psycopg.AsyncClientCursor(aconn) + await cur.execute("select 1") + assert (await cur.fetchone()) == (1,) + + aconn.row_factory = rows.dict_row + cur = psycopg.AsyncClientCursor(aconn) + await cur.execute("select 1 as a") + assert (await cur.fetchone()) == {"a": 1} + + +async def test_init_factory(aconn): + cur = psycopg.AsyncClientCursor(aconn, row_factory=rows.dict_row) + await cur.execute("select 1 as a") + assert (await cur.fetchone()) == {"a": 1} + + +async def test_from_cursor_factory(dsn): + async with await psycopg.AsyncConnection.connect( + dsn, cursor_factory=psycopg.AsyncClientCursor + ) as aconn: + cur = aconn.cursor() + assert type(cur) is psycopg.AsyncClientCursor + + await cur.execute("select %s", (1,)) + assert await cur.fetchone() == (1,) + assert cur._query + assert cur._query.query == b"select 1" + + +async def test_close(aconn): + cur = aconn.cursor() + assert not cur.closed + await cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + await cur.execute("select 'foo'") + + await cur.close() + assert cur.closed + + +async def test_cursor_close_fetchone(aconn): + cur = aconn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + for _ in range(5): + await cur.fetchone() + + await cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + await cur.fetchone() + + +async def test_cursor_close_fetchmany(aconn): + cur = aconn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + assert len(await cur.fetchmany(2)) == 2 + + await cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + await cur.fetchmany(2) + + +async def test_cursor_close_fetchall(aconn): + cur = aconn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + assert len(await cur.fetchall()) == 10 + + await cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + await cur.fetchall() + + +async def test_context(aconn): + async with aconn.cursor() as cur: + assert not cur.closed + + assert cur.closed + + +@pytest.mark.slow +async def test_weakref(aconn): + cur = aconn.cursor() + w = weakref.ref(cur) + await cur.close() + del cur + gc_collect() + assert w() is None + + +async def test_pgresult(aconn): + cur = aconn.cursor() + await cur.execute("select 1") + assert cur.pgresult + await cur.close() + assert not cur.pgresult + + +async def test_statusmessage(aconn): + cur = aconn.cursor() + assert cur.statusmessage is None + + await cur.execute("select generate_series(1, 10)") + assert cur.statusmessage == "SELECT 10" + + await cur.execute("create table statusmessage ()") + assert cur.statusmessage == "CREATE TABLE" + + with pytest.raises(psycopg.ProgrammingError): + await cur.execute("wat") + assert cur.statusmessage is None + + +async def test_execute_sql(aconn): + cur = aconn.cursor() + await cur.execute(sql.SQL("select {value}").format(value="hello")) + assert await cur.fetchone() == ("hello",) + + +async def test_execute_many_results(aconn): + cur = aconn.cursor() + assert cur.nextset() is None + + rv = await cur.execute("select 'foo'; select generate_series(1,3)") + assert rv is cur + assert (await cur.fetchall()) == [("foo",)] + assert cur.rowcount == 1 + assert cur.nextset() + assert (await cur.fetchall()) == [(1,), (2,), (3,)] + assert cur.rowcount == 3 + assert cur.nextset() is None + + await cur.close() + assert cur.nextset() is None + + +async def test_execute_sequence(aconn): + cur = aconn.cursor() + rv = await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) + assert rv is cur + assert len(cur._results) == 1 + assert cur.pgresult.get_value(0, 0) == b"1" + assert cur.pgresult.get_value(0, 1) == b"foo" + assert cur.pgresult.get_value(0, 2) is None + assert cur.nextset() is None + + +@pytest.mark.parametrize("query", ["", " ", ";"]) +async def test_execute_empty_query(aconn, query): + cur = aconn.cursor() + await cur.execute(query) + assert cur.pgresult.status == cur.ExecStatus.EMPTY_QUERY + with pytest.raises(psycopg.ProgrammingError): + await cur.fetchone() + + +async def test_execute_type_change(aconn): + # issue #112 + await aconn.execute("create table bug_112 (num integer)") + sql = "insert into bug_112 (num) values (%s)" + cur = aconn.cursor() + await cur.execute(sql, (1,)) + await cur.execute(sql, (100_000,)) + await cur.execute("select num from bug_112 order by num") + assert (await cur.fetchall()) == [(1,), (100_000,)] + + +async def test_executemany_type_change(aconn): + await aconn.execute("create table bug_112 (num integer)") + sql = "insert into bug_112 (num) values (%s)" + cur = aconn.cursor() + await cur.executemany(sql, [(1,), (100_000,)]) + await cur.execute("select num from bug_112 order by num") + assert (await cur.fetchall()) == [(1,), (100_000,)] + + +@pytest.mark.parametrize( + "query", ["copy testcopy from stdin", "copy testcopy to stdout"] +) +async def test_execute_copy(aconn, query): + cur = aconn.cursor() + await cur.execute("create table testcopy (id int)") + with pytest.raises(psycopg.ProgrammingError): + await cur.execute(query) + + +async def test_fetchone(aconn): + cur = aconn.cursor() + await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) + assert cur.pgresult.fformat(0) == 0 + + row = await cur.fetchone() + assert row == (1, "foo", None) + row = await cur.fetchone() + assert row is None + + +async def test_binary_cursor_execute(aconn): + with pytest.raises(psycopg.NotSupportedError): + cur = aconn.cursor(binary=True) + await cur.execute("select %s, %s", [1, None]) + + +async def test_execute_binary(aconn): + with pytest.raises(psycopg.NotSupportedError): + cur = aconn.cursor() + await cur.execute("select %s, %s", [1, None], binary=True) + + +async def test_binary_cursor_text_override(aconn): + cur = aconn.cursor(binary=True) + await cur.execute("select %s, %s", [1, None], binary=False) + assert (await cur.fetchone()) == (1, None) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"1" + + +@pytest.mark.parametrize("encoding", ["utf8", "latin9"]) +async def test_query_encode(aconn, encoding): + await aconn.execute(f"set client_encoding to {encoding}") + cur = aconn.cursor() + await cur.execute("select '\u20ac'") + (res,) = await cur.fetchone() + assert res == "\u20ac" + + +async def test_query_badenc(aconn): + await aconn.execute("set client_encoding to latin1") + cur = aconn.cursor() + with pytest.raises(UnicodeEncodeError): + await cur.execute("select '\u20ac'") + + +async def test_executemany(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + await cur.execute("select num, data from execmany order by 1") + rv = await cur.fetchall() + assert rv == [(10, "hello"), (20, "world")] + + +async def test_executemany_name(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%(num)s, %(data)s)", + [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}], + ) + await cur.execute("select num, data from execmany order by 1") + rv = await cur.fetchall() + assert rv == [(11, "hello"), (21, "world")] + + +async def test_executemany_no_data(aconn, execmany): + cur = aconn.cursor() + await cur.executemany("insert into execmany(num, data) values (%s, %s)", []) + assert cur.rowcount == 0 + + +async def test_executemany_rowcount(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + + +async def test_executemany_returning(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s) returning num", + [(10, "hello"), (20, "world")], + returning=True, + ) + assert cur.rowcount == 2 + assert (await cur.fetchone()) == (10,) + assert cur.nextset() + assert (await cur.fetchone()) == (20,) + assert cur.nextset() is None + + +async def test_executemany_returning_discard(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s) returning num", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + with pytest.raises(psycopg.ProgrammingError): + await cur.fetchone() + assert cur.nextset() is None + + +async def test_executemany_no_result(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + returning=True, + ) + assert cur.rowcount == 2 + assert cur.statusmessage.startswith("INSERT") + with pytest.raises(psycopg.ProgrammingError): + await cur.fetchone() + pgresult = cur.pgresult + assert cur.nextset() + assert cur.statusmessage.startswith("INSERT") + assert pgresult is not cur.pgresult + assert cur.nextset() is None + + +async def test_executemany_rowcount_no_hit(aconn, execmany): + cur = aconn.cursor() + await cur.executemany("delete from execmany where id = %s", [(-1,), (-2,)]) + assert cur.rowcount == 0 + await cur.executemany("delete from execmany where id = %s", []) + assert cur.rowcount == 0 + await cur.executemany( + "delete from execmany where id = %s returning num", [(-1,), (-2,)] + ) + assert cur.rowcount == 0 + + +@pytest.mark.parametrize( + "query", + [ + "insert into nosuchtable values (%s, %s)", + "copy (select %s, %s) to stdout", + "wat (%s, %s)", + ], +) +async def test_executemany_badquery(aconn, query): + cur = aconn.cursor() + with pytest.raises(psycopg.DatabaseError): + await cur.executemany(query, [(10, "hello"), (20, "world")]) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +async def test_executemany_null_first(aconn, fmt_in): + cur = aconn.cursor() + await cur.execute("create table testmany (a bigint, b bigint)") + await cur.executemany( + f"insert into testmany values (%{fmt_in}, %{fmt_in})", + [[1, None], [3, 4]], + ) + with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)): + await cur.executemany( + f"insert into testmany values (%{fmt_in}, %{fmt_in})", + [[1, ""], [3, 4]], + ) + + +async def test_rowcount(aconn): + cur = aconn.cursor() + + await cur.execute("select 1 from generate_series(1, 0)") + assert cur.rowcount == 0 + + await cur.execute("select 1 from generate_series(1, 42)") + assert cur.rowcount == 42 + + await cur.execute("create table test_rowcount_notuples (id int primary key)") + assert cur.rowcount == -1 + + await cur.execute( + "insert into test_rowcount_notuples select generate_series(1, 42)" + ) + assert cur.rowcount == 42 + + +async def test_rownumber(aconn): + cur = aconn.cursor() + assert cur.rownumber is None + + await cur.execute("select 1 from generate_series(1, 42)") + assert cur.rownumber == 0 + + await cur.fetchone() + assert cur.rownumber == 1 + await cur.fetchone() + assert cur.rownumber == 2 + await cur.fetchmany(10) + assert cur.rownumber == 12 + rns: List[int] = [] + async for i in cur: + assert cur.rownumber + rns.append(cur.rownumber) + if len(rns) >= 3: + break + assert rns == [13, 14, 15] + assert len(await cur.fetchall()) == 42 - rns[-1] + assert cur.rownumber == 42 + + +async def test_iter(aconn): + cur = aconn.cursor() + await cur.execute("select generate_series(1, 3)") + res = [] + async for rec in cur: + res.append(rec) + assert res == [(1,), (2,), (3,)] + + +async def test_iter_stop(aconn): + cur = aconn.cursor() + await cur.execute("select generate_series(1, 3)") + async for rec in cur: + assert rec == (1,) + break + + async for rec in cur: + assert rec == (2,) + break + + assert (await cur.fetchone()) == (3,) + async for rec in cur: + assert False + + +async def test_row_factory(aconn): + cur = aconn.cursor(row_factory=my_row_factory) + await cur.execute("select 'foo' as bar") + (r,) = await cur.fetchone() + assert r == "FOObar" + + await cur.execute("select 'x' as x; select 'y' as y, 'z' as z") + assert await cur.fetchall() == [["Xx"]] + assert cur.nextset() + assert await cur.fetchall() == [["Yy", "Zz"]] + + await cur.scroll(-1) + cur.row_factory = rows.dict_row + assert await cur.fetchone() == {"y": "y", "z": "z"} + + +async def test_row_factory_none(aconn): + cur = aconn.cursor(row_factory=None) + assert cur.row_factory is rows.tuple_row + await cur.execute("select 1 as a, 2 as b") + r = await cur.fetchone() + assert type(r) is tuple + assert r == (1, 2) + + +async def test_bad_row_factory(aconn): + def broken_factory(cur): + 1 / 0 + + cur = aconn.cursor(row_factory=broken_factory) + with pytest.raises(ZeroDivisionError): + await cur.execute("select 1") + + def broken_maker(cur): + def make_row(seq): + 1 / 0 + + return make_row + + cur = aconn.cursor(row_factory=broken_maker) + await cur.execute("select 1") + with pytest.raises(ZeroDivisionError): + await cur.fetchone() + + +async def test_scroll(aconn): + cur = aconn.cursor() + with pytest.raises(psycopg.ProgrammingError): + await cur.scroll(0) + + await cur.execute("select generate_series(0,9)") + await cur.scroll(2) + assert await cur.fetchone() == (2,) + await cur.scroll(2) + assert await cur.fetchone() == (5,) + await cur.scroll(2, mode="relative") + assert await cur.fetchone() == (8,) + await cur.scroll(-1) + assert await cur.fetchone() == (8,) + await cur.scroll(-2) + assert await cur.fetchone() == (7,) + await cur.scroll(2, mode="absolute") + assert await cur.fetchone() == (2,) + + # on the boundary + await cur.scroll(0, mode="absolute") + assert await cur.fetchone() == (0,) + with pytest.raises(IndexError): + await cur.scroll(-1, mode="absolute") + + await cur.scroll(0, mode="absolute") + with pytest.raises(IndexError): + await cur.scroll(-1) + + await cur.scroll(9, mode="absolute") + assert await cur.fetchone() == (9,) + with pytest.raises(IndexError): + await cur.scroll(10, mode="absolute") + + await cur.scroll(9, mode="absolute") + with pytest.raises(IndexError): + await cur.scroll(1) + + with pytest.raises(ValueError): + await cur.scroll(1, "wat") + + +async def test_query_params_execute(aconn): + cur = aconn.cursor() + assert cur._query is None + + await cur.execute("select %t, %s::text", [1, None]) + assert cur._query is not None + assert cur._query.query == b"select 1, NULL::text" + assert cur._query.params == (b"1", b"NULL") + + await cur.execute("select 1") + assert cur._query.query == b"select 1" + assert not cur._query.params + + with pytest.raises(psycopg.DataError): + await cur.execute("select %t::int", ["wat"]) + + assert cur._query.query == b"select 'wat'::int" + assert cur._query.params == (b"'wat'",) + + +async def test_query_params_executemany(aconn): + cur = aconn.cursor() + + await cur.executemany("select %t, %t", [[1, 2], [3, 4]]) + assert cur._query.query == b"select 3, 4" + assert cur._query.params == (b"3", b"4") + + +async def test_stream(aconn): + cur = aconn.cursor() + with pytest.raises(psycopg.NotSupportedError): + async for rec in cur.stream( + "select i, '2021-01-01'::date + i from generate_series(1, %s) as i", + [2], + ): + pass + + +async def test_str(aconn): + cur = aconn.cursor() + assert "[IDLE]" in str(cur) + assert "[closed]" not in str(cur) + assert "[no result]" in str(cur) + await cur.execute("select 1") + assert "[INTRANS]" in str(cur) + assert "[TUPLES_OK]" in str(cur) + assert "[closed]" not in str(cur) + assert "[no result]" not in str(cur) + await cur.close() + assert "[closed]" in str(cur) + assert "[INTRANS]" in str(cur) + + +@pytest.mark.slow +@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"]) +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +async def test_leak(dsn, faker, fetch, row_factory): + faker.choose_schema(ncols=5) + faker.make_records(10) + row_factory = getattr(rows, row_factory) + + async def work(): + async with await psycopg.AsyncConnection.connect(dsn) as conn: + async with psycopg.AsyncClientCursor(conn, row_factory=row_factory) as cur: + await cur.execute(faker.drop_stmt) + await cur.execute(faker.create_stmt) + async with faker.find_insert_problem_async(conn): + await cur.executemany(faker.insert_stmt, faker.records) + await cur.execute(faker.select_stmt) + + if fetch == "one": + while 1: + tmp = await cur.fetchone() + if tmp is None: + break + elif fetch == "many": + while 1: + tmp = await cur.fetchmany(3) + if not tmp: + break + elif fetch == "all": + await cur.fetchall() + elif fetch == "iter": + async for rec in cur: + pass + + n = [] + gc_collect() + for i in range(3): + await work() + gc_collect() + n.append(len(gc.get_objects())) + + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"