]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(tests): tweak test_cursor to make it more similar to the async version
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 7 Aug 2023 08:58:05 +0000 (09:58 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Oct 2023 21:45:38 +0000 (23:45 +0200)
tests/_test_cursor.py [new file with mode: 0644]
tests/test_column.py [new file with mode: 0644]
tests/test_connection.py
tests/test_connection_async.py
tests/test_cursor.py
tests/test_cursor_async.py
tests/test_raw_cursor.py
tests/test_raw_cursor_async.py
tests/utils.py

diff --git a/tests/_test_cursor.py b/tests/_test_cursor.py
new file mode 100644 (file)
index 0000000..69c5063
--- /dev/null
@@ -0,0 +1,60 @@
+"""
+Support module for test_cursor[_async].py
+"""
+
+import re
+from typing import Any, List, Match, Union
+
+import pytest
+import psycopg
+from psycopg.rows import RowMaker
+
+
+@pytest.fixture(scope="session")
+def _execmany(svcconn):
+    cur = svcconn.cursor()
+    cur.execute(
+        """
+        drop table if exists execmany;
+        create table execmany (id serial primary key, num integer, data text)
+        """
+    )
+
+
+@pytest.fixture(scope="function")
+def execmany(svcconn, _execmany):
+    cur = svcconn.cursor()
+    cur.execute("truncate table execmany")
+
+
+def ph(cur: Any, query: str) -> str:
+    """Change placeholders in a query from %s to $n if testing  a raw cursor"""
+    if not isinstance(cur, (psycopg.RawCursor, psycopg.AsyncRawCursor)):
+        return query
+
+    if "%(" in query:
+        raise pytest.skip("RawCursor only supports positional placeholders")
+
+    n = 1
+
+    def s(m: Match[str]) -> str:
+        nonlocal n
+        rv = f"${n}"
+        n += 1
+        return rv
+
+    return re.sub(r"(?<!%)(%[bst])", s, query)
+
+
+def my_row_factory(
+    cursor: Union[psycopg.Cursor[List[str]], psycopg.AsyncCursor[List[str]]]
+) -> RowMaker[List[str]]:
+    if cursor.description is not None:
+        titles = [c.name for c in cursor.description]
+
+        def mkrow(values):
+            return [f"{value.upper()}{title}" for title, value in zip(titles, values)]
+
+        return mkrow
+    else:
+        return psycopg.rows.no_result
diff --git a/tests/test_column.py b/tests/test_column.py
new file mode 100644 (file)
index 0000000..d31181b
--- /dev/null
@@ -0,0 +1,137 @@
+import pickle
+
+import pytest
+
+from psycopg.postgres import types as builtins
+from .fix_crdb import is_crdb, crdb_encoding, crdb_time_precision
+
+
+def test_description_attribs(conn):
+    curs = conn.cursor()
+    curs.execute(
+        """select
+        3.14::decimal(10,2) as pi,
+        'hello'::text as hi,
+        '2010-02-18'::date as now
+        """
+    )
+    assert len(curs.description) == 3
+    for c in curs.description:
+        len(c) == 7  # DBAPI happy
+        for i, a in enumerate(
+            """
+            name type_code display_size internal_size precision scale null_ok
+            """.split()
+        ):
+            assert c[i] == getattr(c, a)
+
+        # Won't fill them up
+        assert c.null_ok is None
+
+    c = curs.description[0]
+    assert c.name == "pi"
+    assert c.type_code == builtins["numeric"].oid
+    assert c.display_size is None
+    assert c.internal_size is None
+    assert c.precision == 10
+    assert c.scale == 2
+
+    c = curs.description[1]
+    assert c.name == "hi"
+    assert c.type_code == builtins["text"].oid
+    assert c.display_size is None
+    assert c.internal_size is None
+    assert c.precision is None
+    assert c.scale is None
+
+    c = curs.description[2]
+    assert c.name == "now"
+    assert c.type_code == builtins["date"].oid
+    assert c.display_size is None
+    if is_crdb(conn) and conn.info.server_version < 230000:
+        assert c.internal_size == 16
+    else:
+        assert c.internal_size == 4
+    assert c.precision is None
+    assert c.scale is None
+
+
+def test_description_slice(conn):
+    curs = conn.cursor()
+    curs.execute("select 1::int as a")
+    curs.description[0][0:2] == ("a", 23)
+
+
+@pytest.mark.parametrize(
+    "type, precision, scale, dsize, isize",
+    [
+        ("text", None, None, None, None),
+        ("varchar", None, None, None, None),
+        ("varchar(42)", None, None, 42, None),
+        ("int4", None, None, None, 4),
+        ("numeric", None, None, None, None),
+        ("numeric(10)", 10, 0, None, None),
+        ("numeric(10, 3)", 10, 3, None, None),
+        ("time", None, None, None, 8),
+        crdb_time_precision("time(4)", 4, None, None, 8),
+        crdb_time_precision("time(10)", 6, None, None, 8),
+    ],
+)
+def test_details(conn, type, precision, scale, dsize, isize):
+    cur = conn.cursor()
+    cur.execute(f"select null::{type}")
+    col = cur.description[0]
+    repr(col)
+    assert col.precision == precision
+    assert col.scale == scale
+    assert col.display_size == dsize
+    assert col.internal_size == isize
+
+
+def test_pickle(conn):
+    curs = conn.cursor()
+    curs.execute(
+        """select
+        3.14::decimal(10,2) as pi,
+        'hello'::text as hi,
+        '2010-02-18'::date as now
+        """
+    )
+    description = curs.description
+    pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL)
+    unpickled = pickle.loads(pickled)
+    assert [tuple(d) for d in description] == [tuple(d) for d in unpickled]
+
+
+@pytest.mark.crdb_skip("no col query")
+def test_no_col_query(conn):
+    cur = conn.execute("select")
+    assert cur.description == []
+    assert cur.fetchall() == [()]
+
+
+def test_description_closed_connection(conn):
+    # If we have reasons to break this test we will (e.g. we really need
+    # the connection). In #172 it fails just by accident.
+    cur = conn.execute("select 1::int4 as foo")
+    conn.close()
+    assert len(cur.description) == 1
+    col = cur.description[0]
+    assert col.name == "foo"
+    assert col.type_code == 23
+
+
+def test_name_not_a_name(conn):
+    cur = conn.cursor()
+    (res,) = cur.execute("""select 'x' as "foo-bar" """).fetchone()
+    assert res == "x"
+    assert cur.description[0].name == "foo-bar"
+
+
+@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+def test_name_encode(conn, encoding):
+    conn.execute(f"set client_encoding to {encoding}")
+    cur = conn.cursor()
+    (res,) = cur.execute("""select 'x' as "\u20ac" """).fetchone()
+    assert res == "x"
+    assert cur.description[0].name == "\u20ac"
index 7314f6f31a93b6e676ac485728caf1d7eaabe03d..6092bebebdc4ffb3fb64faddc42fc25fe5086820 100644 (file)
@@ -12,7 +12,7 @@ from psycopg.rows import tuple_row
 from psycopg.conninfo import conninfo_to_dict, make_conninfo
 
 from .utils import gc_collect
-from .test_cursor import my_row_factory
+from ._test_cursor import my_row_factory
 from .test_adapt import make_bin_dumper, make_dumper
 
 
index 61277872f0da48967dc68f727fdd86ba67fe3e6b..b58b3e5851f237942053fe768551cbf026c1bd73 100644 (file)
@@ -10,7 +10,7 @@ from psycopg.rows import tuple_row
 from psycopg.conninfo import conninfo_to_dict, make_conninfo
 
 from .utils import gc_collect
-from .test_cursor import my_row_factory
+from ._test_cursor import my_row_factory
 from .test_connection import tx_params, tx_params_isolation, tx_values_map
 from .test_connection import conninfo_params_timeout
 from .test_connection import testctx  # noqa: F401  # fixture
index c0d57dac1ce48824a572ebee077f3908cec258ca..326f699fe261d0c3afed3499d018897228e8e4d3 100644 (file)
@@ -2,24 +2,24 @@
 Tests common to psycopg.Cursor and its subclasses.
 """
 
-import re
-import pickle
 import weakref
 import datetime as dt
-from typing import Any, List, Match, Union
+from typing import Any, List
 from contextlib import closing
 
 import pytest
 
 import psycopg
 from psycopg import sql, rows
-from psycopg.rows import RowMaker
 from psycopg.adapt import PyFormat
 from psycopg.types import TypeInfo
-from psycopg.postgres import types as builtins
 
 from .utils import gc_collect, raiseif
-from .fix_crdb import is_crdb, crdb_encoding, crdb_time_precision
+from .fix_crdb import crdb_encoding
+from ._test_cursor import my_row_factory, ph
+from ._test_cursor import execmany, _execmany  # noqa: F401
+
+execmany = execmany  # avoid F811 underneath
 
 
 @pytest.fixture(params=[psycopg.Cursor, psycopg.ClientCursor, psycopg.RawCursor])
@@ -28,25 +28,6 @@ def conn(conn, request):
     return conn
 
 
-def ph(cur: Any, query: str) -> str:
-    """Change placeholders in a query from %s to $n if testing  a raw cursor"""
-    if not isinstance(cur, (psycopg.RawCursor, psycopg.AsyncRawCursor)):
-        return query
-
-    if "%(" in query:
-        raise pytest.skip("RawCursor only supports positional placeholders")
-
-    n = 1
-
-    def s(m: Match[str]) -> str:
-        nonlocal n
-        rv = f"${n}"
-        n += 1
-        return rv
-
-    return re.sub(r"(?<!%)(%[bst])", s, query)
-
-
 def test_init(conn):
     cur = conn.cursor_factory(conn)
     cur.execute("select 1")
@@ -335,23 +316,6 @@ def test_query_badenc(conn, encoding):
         cur.execute("select '\u20ac'")
 
 
-@pytest.fixture(scope="session")
-def _execmany(svcconn):
-    cur = svcconn.cursor()
-    cur.execute(
-        """
-        drop table if exists execmany;
-        create table execmany (id serial primary key, num integer, data text)
-        """
-    )
-
-
-@pytest.fixture(scope="function")
-def execmany(svcconn, _execmany):
-    cur = svcconn.cursor()
-    cur.execute("truncate table execmany")
-
-
 def test_executemany(conn, execmany):
     cur = conn.cursor()
     cur.executemany(
@@ -700,7 +664,7 @@ def test_stream_sql(conn):
 
 def test_stream_row_factory(conn):
     cur = conn.cursor(row_factory=rows.dict_row)
-    it = iter(cur.stream("select generate_series(1,2) as a"))
+    it = cur.stream("select generate_series(1,2) as a")
     assert next(it)["a"] == 1
     cur.row_factory = rows.namedtuple_row
     assert next(it).a == 2
@@ -828,131 +792,6 @@ def test_stream_binary_cursor_text_override(conn):
     assert recs == [(1,), (2,)]
 
 
-class TestColumn:
-    def test_description_attribs(self, conn):
-        curs = conn.cursor()
-        curs.execute(
-            """select
-            3.14::decimal(10,2) as pi,
-            'hello'::text as hi,
-            '2010-02-18'::date as now
-            """
-        )
-        assert len(curs.description) == 3
-        for c in curs.description:
-            len(c) == 7  # DBAPI happy
-            for i, a in enumerate(
-                """
-                name type_code display_size internal_size precision scale null_ok
-                """.split()
-            ):
-                assert c[i] == getattr(c, a)
-
-            # Won't fill them up
-            assert c.null_ok is None
-
-        c = curs.description[0]
-        assert c.name == "pi"
-        assert c.type_code == builtins["numeric"].oid
-        assert c.display_size is None
-        assert c.internal_size is None
-        assert c.precision == 10
-        assert c.scale == 2
-
-        c = curs.description[1]
-        assert c.name == "hi"
-        assert c.type_code == builtins["text"].oid
-        assert c.display_size is None
-        assert c.internal_size is None
-        assert c.precision is None
-        assert c.scale is None
-
-        c = curs.description[2]
-        assert c.name == "now"
-        assert c.type_code == builtins["date"].oid
-        assert c.display_size is None
-        if is_crdb(conn) and conn.info.server_version < 230000:
-            assert c.internal_size == 16
-        else:
-            assert c.internal_size == 4
-        assert c.precision is None
-        assert c.scale is None
-
-    def test_description_slice(self, conn):
-        curs = conn.cursor()
-        curs.execute("select 1::int as a")
-        curs.description[0][0:2] == ("a", 23)
-
-    @pytest.mark.parametrize(
-        "type, precision, scale, dsize, isize",
-        [
-            ("text", None, None, None, None),
-            ("varchar", None, None, None, None),
-            ("varchar(42)", None, None, 42, None),
-            ("int4", None, None, None, 4),
-            ("numeric", None, None, None, None),
-            ("numeric(10)", 10, 0, None, None),
-            ("numeric(10, 3)", 10, 3, None, None),
-            ("time", None, None, None, 8),
-            crdb_time_precision("time(4)", 4, None, None, 8),
-            crdb_time_precision("time(10)", 6, None, None, 8),
-        ],
-    )
-    def test_details(self, conn, type, precision, scale, dsize, isize):
-        cur = conn.cursor()
-        cur.execute(f"select null::{type}")
-        col = cur.description[0]
-        repr(col)
-        assert col.precision == precision
-        assert col.scale == scale
-        assert col.display_size == dsize
-        assert col.internal_size == isize
-
-    def test_pickle(self, conn):
-        curs = conn.cursor()
-        curs.execute(
-            """select
-            3.14::decimal(10,2) as pi,
-            'hello'::text as hi,
-            '2010-02-18'::date as now
-            """
-        )
-        description = curs.description
-        pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL)
-        unpickled = pickle.loads(pickled)
-        assert [tuple(d) for d in description] == [tuple(d) for d in unpickled]
-
-    @pytest.mark.crdb_skip("no col query")
-    def test_no_col_query(self, conn):
-        cur = conn.execute("select")
-        assert cur.description == []
-        assert cur.fetchall() == [()]
-
-    def test_description_closed_connection(self, conn):
-        # If we have reasons to break this test we will (e.g. we really need
-        # the connection). In #172 it fails just by accident.
-        cur = conn.execute("select 1::int4 as foo")
-        conn.close()
-        assert len(cur.description) == 1
-        col = cur.description[0]
-        assert col.name == "foo"
-        assert col.type_code == 23
-
-    def test_name_not_a_name(self, conn):
-        cur = conn.cursor()
-        (res,) = cur.execute("""select 'x' as "foo-bar" """).fetchone()
-        assert res == "x"
-        assert cur.description[0].name == "foo-bar"
-
-    @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
-    def test_name_encode(self, conn, encoding):
-        conn.execute(f"set client_encoding to {encoding}")
-        cur = conn.cursor()
-        (res,) = cur.execute("""select 'x' as "\u20ac" """).fetchone()
-        assert res == "x"
-        assert cur.description[0].name == "\u20ac"
-
-
 def test_str(conn):
     cur = conn.cursor()
     assert "[IDLE]" in str(cur)
@@ -985,17 +824,3 @@ def test_message_0x33(conn):
 def test_typeinfo(conn):
     info = TypeInfo.fetch(conn, "jsonb")
     assert info is not None
-
-
-def my_row_factory(
-    cursor: Union[psycopg.Cursor[List[str]], psycopg.AsyncCursor[List[str]]]
-) -> RowMaker[List[str]]:
-    if cursor.description is not None:
-        titles = [c.name for c in cursor.description]
-
-        def mkrow(values):
-            return [f"{value.upper()}{title}" for title, value in zip(titles, values)]
-
-        return mkrow
-    else:
-        return rows.no_result
index df70f0130be36c72e656570244e3f6aabe3da074..155d23ad9d6e0b7fdf2f6f44acd18c5f256772c7 100644 (file)
@@ -1,17 +1,24 @@
-import pytest
+"""
+Tests common to psycopg.AsyncCursor and its subclasses.
+"""
+
 import weakref
 import datetime as dt
-from typing import List
+from typing import Any, List
+from contextlib import aclosing
+
+import pytest
 
 import psycopg
 from psycopg import sql, rows
 from psycopg.adapt import PyFormat
 from psycopg.types import TypeInfo
 
-from .utils import alist, gc_collect, raiseif
-from .test_cursor import my_row_factory, ph
-from .test_cursor import execmany, _execmany  # noqa: F401
+from .utils import alist, anext
+from .utils import gc_collect, raiseif
 from .fix_crdb import crdb_encoding
+from ._test_cursor import my_row_factory, ph
+from ._test_cursor import execmany, _execmany  # noqa: F401
 
 execmany = execmany  # avoid F811 underneath
 
@@ -19,7 +26,7 @@ execmany = execmany  # avoid F811 underneath
 @pytest.fixture(
     params=[psycopg.AsyncCursor, psycopg.AsyncClientCursor, psycopg.AsyncRawCursor]
 )
-def aconn(aconn, request, anyio_backend):
+async def aconn(aconn, request, anyio_backend):
     aconn.cursor_factory = request.param
     return aconn
 
@@ -146,6 +153,42 @@ async def test_execute_sql(aconn):
     assert (await cur.fetchone()) == ("hello",)
 
 
+async def test_query_parse_cache_size(aconn):
+    cur = aconn.cursor()
+    cls = type(cur)
+
+    # Warning: testing internal structures. Test might need refactoring with the code.
+    cache: Any
+    if cls is psycopg.AsyncCursor:
+        cache = psycopg._queries._query2pg
+    elif cls is psycopg.AsyncClientCursor:
+        cache = psycopg._queries._query2pg_client
+    elif cls is psycopg.AsyncRawCursor:
+        pytest.skip("RawCursor has no query parse cache")
+    else:
+        assert False, cls
+
+    cache.cache_clear()
+    ci = cache.cache_info()
+    h0, m0 = ci.hits, ci.misses
+    tests = [
+        (f"select 1 -- {'x' * 3500}", (), h0, m0 + 1),
+        (f"select 1 -- {'x' * 3500}", (), h0 + 1, m0 + 1),
+        (f"select 1 -- {'x' * 4500}", (), h0 + 1, m0 + 1),
+        (f"select 1 -- {'x' * 4500}", (), h0 + 1, m0 + 1),
+        (f"select 1 -- {'%s' * 40}", ("x",) * 40, h0 + 1, m0 + 2),
+        (f"select 1 -- {'%s' * 40}", ("x",) * 40, h0 + 2, m0 + 2),
+        (f"select 1 -- {'%s' * 60}", ("x",) * 60, h0 + 2, m0 + 2),
+        (f"select 1 -- {'%s' * 60}", ("x",) * 60, h0 + 2, m0 + 2),
+    ]
+    for i, (query, params, hits, misses) in enumerate(tests):
+        pq = cur._query_cls(psycopg.adapt.Transformer())
+        pq.convert(query, params)
+        ci = cache.cache_info()
+        assert ci.hits == hits, f"at {i}"
+        assert ci.misses == misses, f"at {i}"
+
+
 async def test_execute_many_results(aconn):
     cur = aconn.cursor()
     assert cur.nextset() is None
@@ -459,13 +502,12 @@ async def test_rownumber_none(aconn, query):
 
 async def test_rownumber_mixed(aconn):
     cur = aconn.cursor()
-    await cur.execute(
-        """
-select x from generate_series(1, 3) x;
-set timezone to utc;
-select x from generate_series(4, 6) x;
-"""
-    )
+    queries = [
+        "select x from generate_series(1, 3) x",
+        "set timezone to utc",
+        "select x from generate_series(4, 6) x",
+    ]
+    await cur.execute(";\n".join(queries))
     assert cur.rownumber == 0
     assert await cur.fetchone() == (1,)
     assert cur.rownumber == 1
@@ -497,8 +539,7 @@ async def test_iter_stop(aconn):
         break
 
     assert (await cur.fetchone()) == (3,)
-    async for rec in cur:
-        assert False
+    assert (await alist(cur)) == []
 
 
 async def test_row_factory(aconn):
@@ -635,22 +676,22 @@ async def test_stream_sql(aconn):
 
 async def test_stream_row_factory(aconn):
     cur = aconn.cursor(row_factory=rows.dict_row)
-    ait = cur.stream("select generate_series(1,2) as a")
-    assert (await ait.__anext__())["a"] == 1
+    it = cur.stream("select generate_series(1,2) as a")
+    assert (await anext(it))["a"] == 1
     cur.row_factory = rows.namedtuple_row
-    assert (await ait.__anext__()).a == 2
+    assert (await anext(it)).a == 2
 
 
 async def test_stream_no_row(aconn):
     cur = aconn.cursor()
-    recs = [rec async for rec in cur.stream("select generate_series(2,1) as a")]
+    recs = await alist(cur.stream("select generate_series(2,1) as a"))
     assert recs == []
 
 
 @pytest.mark.crdb_skip("no col query")
 async def test_stream_no_col(aconn):
     cur = aconn.cursor()
-    recs = [rec async for rec in cur.stream("select")]
+    recs = await alist(cur.stream("select"))
     assert recs == [()]
 
 
@@ -689,11 +730,9 @@ async def test_stream_error_notx(aconn):
 async def test_stream_error_python_to_consume(aconn):
     cur = aconn.cursor()
     with pytest.raises(ZeroDivisionError):
-        gen = cur.stream("select generate_series(1, 10000)")
-        async for rec in gen:
-            1 / 0
-
-    await gen.aclose()
+        async with aclosing(cur.stream("select generate_series(1, 10000)")) as gen:
+            async for rec in gen:
+                1 / 0
     assert aconn.info.transaction_status in (
         aconn.TransactionStatus.INTRANS,
         aconn.TransactionStatus.INERROR,
index 64169dd2feb9ac7cae008a290fef732af28f01e2..fd6fe9bc5854117c2444b563eda61bc61aca8937 100644 (file)
@@ -3,7 +3,7 @@ import psycopg
 from psycopg import pq, rows, errors as e
 from psycopg.adapt import PyFormat
 
-from .test_cursor import ph
+from ._test_cursor import ph
 from .utils import gc_collect, gc_count
 
 
index 037ee7ffdbd1005d92930aea2393412c03ed66c4..189a208f18aeb145a307165fe958ed81e6e88cf1 100644 (file)
@@ -3,7 +3,7 @@ import psycopg
 from psycopg import pq, rows, errors as e
 from psycopg.adapt import PyFormat
 
-from .test_cursor import ph
+from ._test_cursor import ph
 from .utils import gc_collect, gc_count
 
 
index 543ee975282ee8dd0ed1e39fb989c192d3d6123e..57252e0dcdfcb12d623a1e2df8a20c6035911ee3 100644 (file)
@@ -180,6 +180,10 @@ async def alist(it):
     return [i async for i in it]
 
 
+async def anext(it):
+    return await it.__anext__()
+
+
 @contextmanager
 def raiseif(cond, *args, **kwargs):
     """