From: Daniele Varrazzo Date: Tue, 7 Jun 2022 01:52:11 +0000 (+0200) Subject: test(crdb): get connection class from fixture to make sure to test CRDB objects X-Git-Tag: 3.1~49^2~24 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=be4227c93f637527a0b63d58f9f269f74a10fedc;p=thirdparty%2Fpsycopg.git test(crdb): get connection class from fixture to make sure to test CRDB objects --- diff --git a/tests/fix_db.py b/tests/fix_db.py index 462732701..c2c8527ee 100644 --- a/tests/fix_db.py +++ b/tests/fix_db.py @@ -7,7 +7,6 @@ from typing import List, Optional import psycopg from psycopg import pq from psycopg import sql -from psycopg.crdb import CrdbConnection from .utils import check_libpq_version, check_server_version @@ -148,11 +147,11 @@ def pgconn(dsn, request, tracefile): @pytest.fixture -def conn(dsn, request, tracefile): +def conn(conn_cls, dsn, request, tracefile): """Return a `Connection` connected to the ``--test-dsn`` database.""" check_connection_version(request.node) - conn = connection_class().connect(dsn) + conn = conn_cls.connect(dsn) with maybe_trace(conn.pgconn, tracefile, request.function): yield conn conn.close() @@ -172,17 +171,11 @@ def pipeline(request, conn): @pytest.fixture -async def aconn(dsn, request, tracefile): +async def aconn(dsn, aconn_cls, request, tracefile): """Return an `AsyncConnection` connected to the ``--test-dsn`` database.""" check_connection_version(request.node) - cls = psycopg.AsyncConnection - if crdb_version: - from psycopg.crdb import AsyncCrdbConnection - - cls = AsyncCrdbConnection - - conn = await cls.connect(dsn) + conn = await aconn_cls.connect(dsn) with maybe_trace(conn.pgconn, tracefile, request.function): yield conn await conn.close() @@ -201,20 +194,34 @@ async def apipeline(request, aconn): yield None -def connection_class(): +@pytest.fixture(scope="session") +def conn_cls(session_dsn): cls = psycopg.Connection if crdb_version: + from psycopg.crdb import CrdbConnection + cls = CrdbConnection return cls @pytest.fixture(scope="session") -def svcconn(session_dsn): +def aconn_cls(session_dsn): + cls = psycopg.AsyncConnection + if crdb_version: + from psycopg.crdb import AsyncCrdbConnection + + cls = AsyncCrdbConnection + + return cls + + +@pytest.fixture(scope="session") +def svcconn(conn_cls, session_dsn): """ Return a session `Connection` connected to the ``--test-dsn`` database. """ - conn = psycopg.Connection.connect(session_dsn, autocommit=True) + conn = conn_cls.connect(session_dsn, autocommit=True) yield conn conn.close() diff --git a/tests/pool/test_null_pool.py b/tests/pool/test_null_pool.py index afbdc7572..e3d64c45c 100644 --- a/tests/pool/test_null_pool.py +++ b/tests/pool/test_null_pool.py @@ -608,9 +608,9 @@ def test_close_no_threads(dsn): assert not t.is_alive() -def test_putconn_no_pool(dsn): +def test_putconn_no_pool(conn_cls, dsn): with NullConnectionPool(dsn) as p: - conn = psycopg.connect(dsn) + conn = conn_cls.connect(dsn) with pytest.raises(ValueError): p.putconn(conn) diff --git a/tests/pool/test_null_pool_async.py b/tests/pool/test_null_pool_async.py index 987a35fb3..f88d76a94 100644 --- a/tests/pool/test_null_pool_async.py +++ b/tests/pool/test_null_pool_async.py @@ -584,9 +584,9 @@ async def test_close_no_tasks(dsn): assert t.done() -async def test_putconn_no_pool(dsn): +async def test_putconn_no_pool(aconn_cls, dsn): async with AsyncNullConnectionPool(dsn) as p: - conn = await psycopg.AsyncConnection.connect(dsn) + conn = await aconn_cls.connect(dsn) with pytest.raises(ValueError): await p.putconn(conn) diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py index 13e161b67..22d0bfe8b 100644 --- a/tests/pool/test_pool.py +++ b/tests/pool/test_pool.py @@ -577,9 +577,9 @@ def test_close_no_threads(dsn): assert not t.is_alive() -def test_putconn_no_pool(dsn): +def test_putconn_no_pool(conn_cls, dsn): with pool.ConnectionPool(dsn, min_size=1) as p: - conn = psycopg.connect(dsn) + conn = conn_cls.connect(dsn) with pytest.raises(ValueError): p.putconn(conn) diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py index dc20dbcd8..b04b2566e 100644 --- a/tests/pool/test_pool_async.py +++ b/tests/pool/test_pool_async.py @@ -556,9 +556,9 @@ async def test_close_no_tasks(dsn): assert t.done() -async def test_putconn_no_pool(dsn): +async def test_putconn_no_pool(aconn_cls, dsn): async with pool.AsyncConnectionPool(dsn, min_size=1) as p: - conn = await psycopg.AsyncConnection.connect(dsn) + conn = await aconn_cls.connect(dsn) with pytest.raises(ValueError): await p.putconn(conn) diff --git a/tests/test_adapt.py b/tests/test_adapt.py index 7e76cb489..43424cad6 100644 --- a/tests/test_adapt.py +++ b/tests/test_adapt.py @@ -62,10 +62,10 @@ def test_register_dumper_by_class_name(conn): @pytest.mark.crdb("skip", reason="global adapters don't affect crdb") -def test_dump_global_ctx(dsn, global_adapters, pgconn): +def test_dump_global_ctx(conn_cls, dsn, global_adapters, pgconn): psycopg.adapters.register_dumper(MyStr, make_bin_dumper("gb")) psycopg.adapters.register_dumper(MyStr, make_dumper("gt")) - with psycopg.connect(dsn) as conn: + with conn_cls.connect(dsn) as conn: cur = conn.execute("select %s", [MyStr("hello")]) assert cur.fetchone() == ("hellogt",) cur = conn.execute("select %b", [MyStr("hello")]) @@ -199,10 +199,10 @@ def test_register_loader_by_type_name(conn): @pytest.mark.crdb("skip", reason="global adapters don't affect crdb") -def test_load_global_ctx(dsn, global_adapters): +def test_load_global_ctx(conn_cls, dsn, global_adapters): psycopg.adapters.register_loader("text", make_loader("gt")) psycopg.adapters.register_loader("text", make_bin_loader("gb")) - with psycopg.connect(dsn) as conn: + with conn_cls.connect(dsn) as conn: cur = conn.cursor(binary=False).execute("select 'hello'::text") assert cur.fetchone() == ("hellogt",) cur = conn.cursor(binary=True).execute("select 'hello'::text") diff --git a/tests/test_client_cursor.py b/tests/test_client_cursor.py index bc28df41f..c2cf43e4a 100644 --- a/tests/test_client_cursor.py +++ b/tests/test_client_cursor.py @@ -39,9 +39,9 @@ def test_init_factory(conn): assert cur.fetchone() == {"a": 1} -def test_from_cursor_factory(dsn): - with psycopg.connect(dsn, cursor_factory=psycopg.ClientCursor) as aconn: - cur = aconn.cursor() +def test_from_cursor_factory(conn_cls, dsn): + with conn_cls.connect(dsn, cursor_factory=psycopg.ClientCursor) as conn: + cur = conn.cursor() assert type(cur) is psycopg.ClientCursor cur.execute("select %s", (1,)) @@ -755,13 +755,13 @@ def test_str(conn): @pytest.mark.slow @pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"]) @pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) -def test_leak(dsn, faker, fetch, row_factory): +def test_leak(conn_cls, dsn, faker, fetch, row_factory): faker.choose_schema(ncols=5) faker.make_records(10) row_factory = getattr(rows, row_factory) def work(): - with psycopg.connect(dsn) as conn, conn.transaction(force_rollback=True): + with conn_cls.connect(dsn) as conn, conn.transaction(force_rollback=True): with psycopg.ClientCursor(conn, row_factory=row_factory) as cur: cur.execute(faker.drop_stmt) cur.execute(faker.create_stmt) diff --git a/tests/test_client_cursor_async.py b/tests/test_client_cursor_async.py index 56d744fd4..b73173900 100644 --- a/tests/test_client_cursor_async.py +++ b/tests/test_client_cursor_async.py @@ -40,8 +40,8 @@ async def test_init_factory(aconn): assert (await cur.fetchone()) == {"a": 1} -async def test_from_cursor_factory(dsn): - async with await psycopg.AsyncConnection.connect( +async def test_from_cursor_factory(aconn_cls, dsn): + async with await aconn_cls.connect( dsn, cursor_factory=psycopg.AsyncClientCursor ) as aconn: cur = aconn.cursor() @@ -625,13 +625,13 @@ async def test_str(aconn): @pytest.mark.slow @pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"]) @pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) -async def test_leak(dsn, faker, fetch, row_factory): +async def test_leak(aconn_cls, dsn, faker, fetch, row_factory): faker.choose_schema(ncols=5) faker.make_records(10) row_factory = getattr(rows, row_factory) async def work(): - async with await psycopg.AsyncConnection.connect(dsn) as conn, conn.transaction( + async with await aconn_cls.connect(dsn) as conn, conn.transaction( force_rollback=True ): async with psycopg.AsyncClientCursor(conn, row_factory=row_factory) as cur: diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index 3b3f2ae11..18c716127 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -18,9 +18,9 @@ from psycopg import errors as e @pytest.mark.slow -def test_concurrent_execution(dsn): +def test_concurrent_execution(conn_cls, dsn): def worker(): - cnn = psycopg.connect(dsn) + cnn = conn_cls.connect(dsn) cur = cnn.cursor() cur.execute("select pg_sleep(0.5)") cur.close() @@ -110,8 +110,8 @@ t.join() @pytest.mark.slow @pytest.mark.timing @pytest.mark.crdb("skip", reason="notify") -def test_notifies(conn, dsn): - nconn = psycopg.connect(dsn, autocommit=True) +def test_notifies(conn_cls, conn, dsn): + nconn = conn_cls.connect(dsn, autocommit=True) npid = nconn.pgconn.backend_pid def notifier(): @@ -185,8 +185,8 @@ def test_cancel(conn): t.join() -@pytest.mark.crdb("skip", reason="pg_terminate_backend") @pytest.mark.slow +@pytest.mark.crdb("skip", reason="cancel") def test_cancel_stream(conn): errors: List[Exception] = [] @@ -210,14 +210,15 @@ def test_cancel_stream(conn): t.join() +@pytest.mark.crdb("skip", reason="pg_terminate_backend") @pytest.mark.slow -def test_identify_closure(dsn): +def test_identify_closure(conn_cls, dsn): def closer(): time.sleep(0.2) conn2.execute("select pg_terminate_backend(%s)", [conn.pgconn.backend_pid]) - conn = psycopg.connect(dsn) - conn2 = psycopg.connect(dsn) + conn = conn_cls.connect(dsn) + conn2 = conn_cls.connect(dsn) try: t = threading.Thread(target=closer) t.start() diff --git a/tests/test_concurrency_async.py b/tests/test_concurrency_async.py index 39f9de86e..06f660f81 100644 --- a/tests/test_concurrency_async.py +++ b/tests/test_concurrency_async.py @@ -45,9 +45,9 @@ async def test_commit_concurrency(aconn): @pytest.mark.slow -async def test_concurrent_execution(dsn): +async def test_concurrent_execution(aconn_cls, dsn): async def worker(): - cnn = await psycopg.AsyncConnection.connect(dsn) + cnn = await aconn_cls.connect(dsn) cur = cnn.cursor() await cur.execute("select pg_sleep(0.5)") await cur.close() @@ -62,8 +62,8 @@ async def test_concurrent_execution(dsn): @pytest.mark.slow @pytest.mark.timing @pytest.mark.crdb("skip", reason="notify") -async def test_notifies(aconn, dsn): - nconn = await psycopg.AsyncConnection.connect(dsn, autocommit=True) +async def test_notifies(aconn_cls, aconn, dsn): + nconn = await aconn_cls.connect(dsn, autocommit=True) npid = nconn.pgconn.backend_pid async def notifier(): @@ -137,6 +137,7 @@ async def test_cancel(aconn): @pytest.mark.slow +@pytest.mark.crdb("skip", reason="cancel") async def test_cancel_stream(aconn): async def worker(): cur = aconn.cursor() @@ -163,15 +164,15 @@ async def test_cancel_stream(aconn): @pytest.mark.slow @pytest.mark.crdb("skip", reason="pg_terminate_backend") -async def test_identify_closure(dsn): +async def test_identify_closure(aconn_cls, dsn): async def closer(): await asyncio.sleep(0.2) await conn2.execute( "select pg_terminate_backend(%s)", [aconn.pgconn.backend_pid] ) - aconn = await psycopg.AsyncConnection.connect(dsn) - conn2 = await psycopg.AsyncConnection.connect(dsn) + aconn = await aconn_cls.connect(dsn) + conn2 = await aconn_cls.connect(dsn) try: t = create_task(closer()) t0 = time.time() diff --git a/tests/test_connection.py b/tests/test_connection.py index 42a32bbfe..ff085bb1e 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -6,7 +6,7 @@ from typing import Any, List from dataclasses import dataclass import psycopg -from psycopg import Connection, Notify, errors as e +from psycopg import Notify, errors as e from psycopg.rows import tuple_row from psycopg.conninfo import conninfo_to_dict, make_conninfo @@ -15,34 +15,34 @@ from .test_cursor import my_row_factory from .test_adapt import make_bin_dumper, make_dumper -def test_connect(dsn): - conn = Connection.connect(dsn) +def test_connect(conn_cls, dsn): + conn = conn_cls.connect(dsn) assert not conn.closed assert conn.pgconn.status == conn.ConnStatus.OK conn.close() -def test_connect_str_subclass(dsn): +def test_connect_str_subclass(conn_cls, dsn): class MyString(str): pass - conn = Connection.connect(MyString(dsn)) + conn = conn_cls.connect(MyString(dsn)) assert not conn.closed assert conn.pgconn.status == conn.ConnStatus.OK conn.close() -def test_connect_bad(): +def test_connect_bad(conn_cls): with pytest.raises(psycopg.OperationalError): - Connection.connect("dbname=nosuchdb") + conn_cls.connect("dbname=nosuchdb") @pytest.mark.slow @pytest.mark.timing -def test_connect_timeout(deaf_port): +def test_connect_timeout(conn_cls, deaf_port): t0 = time.time() with pytest.raises(psycopg.OperationalError, match="timeout expired"): - Connection.connect(host="localhost", port=deaf_port, connect_timeout=1) + conn_cls.connect(host="localhost", port=deaf_port, connect_timeout=1) elapsed = time.time() - t0 assert elapsed == pytest.approx(1.0, abs=0.05) @@ -86,22 +86,22 @@ def test_cursor_closed(conn): conn.cursor() -def test_connection_warn_close(dsn, recwarn): - conn = Connection.connect(dsn) +def test_connection_warn_close(conn_cls, dsn, recwarn): + conn = conn_cls.connect(dsn) conn.close() del conn assert not recwarn, [str(w.message) for w in recwarn.list] - conn = Connection.connect(dsn) + conn = conn_cls.connect(dsn) del conn assert "IDLE" in str(recwarn.pop(ResourceWarning).message) - conn = Connection.connect(dsn) + conn = conn_cls.connect(dsn) conn.execute("select 1") del conn assert "INTRANS" in str(recwarn.pop(ResourceWarning).message) - conn = Connection.connect(dsn) + conn = conn_cls.connect(dsn) try: conn.execute("select wat") except Exception: @@ -109,7 +109,7 @@ def test_connection_warn_close(dsn, recwarn): del conn assert "INERROR" in str(recwarn.pop(ResourceWarning).message) - with Connection.connect(dsn) as conn: + with conn_cls.connect(dsn) as conn: pass del conn assert not recwarn, [str(w.message) for w in recwarn.list] @@ -122,7 +122,7 @@ def testctx(svcconn): return None -def test_context_commit(testctx, conn, dsn): +def test_context_commit(conn_cls, testctx, conn, dsn): with conn: with conn.cursor() as cur: cur.execute("insert into testctx values (42)") @@ -130,13 +130,13 @@ def test_context_commit(testctx, conn, dsn): assert conn.closed assert not conn.broken - with psycopg.connect(dsn) as conn: + with conn_cls.connect(dsn) as conn: with conn.cursor() as cur: cur.execute("select * from testctx") assert cur.fetchall() == [(42,)] -def test_context_rollback(testctx, conn, dsn): +def test_context_rollback(conn_cls, testctx, conn, dsn): with pytest.raises(ZeroDivisionError): with conn: with conn.cursor() as cur: @@ -146,7 +146,7 @@ def test_context_rollback(testctx, conn, dsn): assert conn.closed assert not conn.broken - with psycopg.connect(dsn) as conn: + with conn_cls.connect(dsn) as conn: with conn.cursor() as cur: cur.execute("select * from testctx") assert cur.fetchall() == [] @@ -159,11 +159,11 @@ def test_context_close(conn): @pytest.mark.crdb("skip", reason="pg_terminate_backend") -def test_context_inerror_rollback_no_clobber(conn, dsn, caplog): +def test_context_inerror_rollback_no_clobber(conn_cls, conn, dsn, caplog): caplog.set_level(logging.WARNING, logger="psycopg") with pytest.raises(ZeroDivisionError): - with psycopg.connect(dsn) as conn2: + with conn_cls.connect(dsn) as conn2: conn2.execute("select 1") conn.execute( "select pg_terminate_backend(%s::int)", @@ -178,11 +178,11 @@ def test_context_inerror_rollback_no_clobber(conn, dsn, caplog): @pytest.mark.crdb("skip", reason="copy") -def test_context_active_rollback_no_clobber(dsn, caplog): +def test_context_active_rollback_no_clobber(conn_cls, dsn, caplog): caplog.set_level(logging.WARNING, logger="psycopg") with pytest.raises(ZeroDivisionError): - with psycopg.connect(dsn) as conn: + with conn_cls.connect(dsn) as conn: conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout") assert not conn.pgconn.error_message status = conn.info.transaction_status @@ -196,8 +196,8 @@ def test_context_active_rollback_no_clobber(dsn, caplog): @pytest.mark.slow -def test_weakref(dsn): - conn = psycopg.connect(dsn) +def test_weakref(conn_cls, dsn): + conn = conn_cls.connect(dsn) w = weakref.ref(conn) conn.close() del conn @@ -310,8 +310,8 @@ def test_autocommit(conn): assert conn.autocommit is True -def test_autocommit_connect(dsn): - conn = Connection.connect(dsn, autocommit=True) +def test_autocommit_connect(conn_cls, dsn): + conn = conn_cls.connect(dsn, autocommit=True) assert conn.autocommit conn.close() @@ -358,7 +358,7 @@ def test_autocommit_unknown(conn): (("host=foo",), {"user": None}, "host=foo"), ], ) -def test_connect_args(monkeypatch, pgconn, args, kwargs, want): +def test_connect_args(conn_cls, monkeypatch, pgconn, args, kwargs, want): the_conninfo: str def fake_connect(conninfo): @@ -368,7 +368,7 @@ def test_connect_args(monkeypatch, pgconn, args, kwargs, want): yield monkeypatch.setattr(psycopg.connection, "connect", fake_connect) - conn = psycopg.Connection.connect(*args, **kwargs) + conn = conn_cls.connect(*args, **kwargs) assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want) conn.close() @@ -381,14 +381,14 @@ def test_connect_args(monkeypatch, pgconn, args, kwargs, want): ((), {"nosuchparam": 42}, psycopg.ProgrammingError), ], ) -def test_connect_badargs(monkeypatch, pgconn, args, kwargs, exctype): +def test_connect_badargs(conn_cls, monkeypatch, pgconn, args, kwargs, exctype): def fake_connect(conninfo): return pgconn yield monkeypatch.setattr(psycopg.connection, "connect", fake_connect) with pytest.raises(exctype): - psycopg.Connection.connect(*args, **kwargs) + conn_cls.connect(*args, **kwargs) @pytest.mark.crdb("skip", reason="pg_terminate_backend") @@ -502,13 +502,13 @@ def test_execute_binary(conn): assert cur.pgresult.fformat(0) == 1 -def test_row_factory(dsn): - defaultconn = Connection.connect(dsn) - assert defaultconn.row_factory is tuple_row # type: ignore[comparison-overlap] +def test_row_factory(conn_cls, dsn): + defaultconn = conn_cls.connect(dsn) + assert defaultconn.row_factory is tuple_row defaultconn.close() - conn = Connection.connect(dsn, row_factory=my_row_factory) - assert conn.row_factory is my_row_factory # type: ignore[comparison-overlap] + conn = conn_cls.connect(dsn, row_factory=my_row_factory) + assert conn.row_factory is my_row_factory cur = conn.execute("select 'a' as ve") assert cur.fetchone() == ["Ave"] @@ -522,10 +522,10 @@ def test_row_factory(dsn): assert cur2.fetchall() == [(1, 1, 2)] # TODO: maybe fix something to get rid of 'type: ignore' below. - conn.row_factory = tuple_row # type: ignore[assignment] + conn.row_factory = tuple_row cur3 = conn.execute("select 'vale'") r = cur3.fetchone() - assert r and r == ("vale",) # type: ignore[comparison-overlap] + assert r and r == ("vale",) conn.close() @@ -556,11 +556,11 @@ def test_cursor_factory(conn): assert isinstance(cur, MyCursor) -def test_cursor_factory_connect(dsn): +def test_cursor_factory_connect(conn_cls, dsn): class MyCursor(psycopg.Cursor[psycopg.rows.Row]): pass - with psycopg.connect(dsn, cursor_factory=MyCursor) as conn: + with conn_cls.connect(dsn, cursor_factory=MyCursor) as conn: assert conn.cursor_factory is MyCursor cur = conn.cursor() assert type(cur) is MyCursor @@ -746,37 +746,37 @@ conninfo_params_timeout = [ @pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout) -def test_get_connection_params(dsn, kwargs, exp): - params = Connection._get_connection_params(dsn, **kwargs) +def test_get_connection_params(conn_cls, dsn, kwargs, exp): + params = conn_cls._get_connection_params(dsn, **kwargs) conninfo = make_conninfo(**params) assert conninfo_to_dict(conninfo) == exp[0] assert params.get("connect_timeout") == exp[1] -def test_connect_context(dsn): +def test_connect_context(conn_cls, dsn): ctx = psycopg.adapt.AdaptersMap(psycopg.adapters) ctx.register_dumper(str, make_bin_dumper("b")) ctx.register_dumper(str, make_dumper("t")) - conn = psycopg.connect(dsn, context=ctx) + conn = conn_cls.connect(dsn, context=ctx) cur = conn.execute("select %s", ["hello"]) - assert cur.fetchone()[0] == "hellot" # type: ignore[index] + assert cur.fetchone()[0] == "hellot" cur = conn.execute("select %b", ["hello"]) - assert cur.fetchone()[0] == "hellob" # type: ignore[index] + assert cur.fetchone()[0] == "hellob" conn.close() -def test_connect_context_copy(dsn, conn): +def test_connect_context_copy(conn_cls, dsn, conn): conn.adapters.register_dumper(str, make_bin_dumper("b")) conn.adapters.register_dumper(str, make_dumper("t")) - conn2 = psycopg.connect(dsn, context=conn) + conn2 = conn_cls.connect(dsn, context=conn) cur = conn2.execute("select %s", ["hello"]) - assert cur.fetchone()[0] == "hellot" # type: ignore[index] + assert cur.fetchone()[0] == "hellot" cur = conn2.execute("select %b", ["hello"]) - assert cur.fetchone()[0] == "hellob" # type: ignore[index] + assert cur.fetchone()[0] == "hellob" conn2.close() diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index bee4fb417..5f3c832b0 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -5,7 +5,7 @@ import weakref from typing import List, Any import psycopg -from psycopg import AsyncConnection, Notify, errors as e +from psycopg import Notify, errors as e from psycopg.rows import tuple_row from psycopg.conninfo import conninfo_to_dict, make_conninfo @@ -20,23 +20,23 @@ from .test_conninfo import fake_resolve # noqa: F401 pytestmark = pytest.mark.asyncio -async def test_connect(dsn): - conn = await AsyncConnection.connect(dsn) +async def test_connect(aconn_cls, dsn): + conn = await aconn_cls.connect(dsn) assert not conn.closed assert conn.pgconn.status == conn.ConnStatus.OK await conn.close() -async def test_connect_bad(): +async def test_connect_bad(aconn_cls): with pytest.raises(psycopg.OperationalError): - await AsyncConnection.connect("dbname=nosuchdb") + await aconn_cls.connect("dbname=nosuchdb") -async def test_connect_str_subclass(dsn): +async def test_connect_str_subclass(aconn_cls, dsn): class MyString(str): pass - conn = await AsyncConnection.connect(MyString(dsn)) + conn = await aconn_cls.connect(MyString(dsn)) assert not conn.closed assert conn.pgconn.status == conn.ConnStatus.OK await conn.close() @@ -44,12 +44,10 @@ async def test_connect_str_subclass(dsn): @pytest.mark.slow @pytest.mark.timing -async def test_connect_timeout(deaf_port): +async def test_connect_timeout(aconn_cls, deaf_port): t0 = time.time() with pytest.raises(psycopg.OperationalError, match="timeout expired"): - await AsyncConnection.connect( - host="localhost", port=deaf_port, connect_timeout=1 - ) + await aconn_cls.connect(host="localhost", port=deaf_port, connect_timeout=1) elapsed = time.time() - t0 assert elapsed == pytest.approx(1.0, abs=0.05) @@ -96,22 +94,22 @@ async def test_cursor_closed(aconn): aconn.cursor() -async def test_connection_warn_close(dsn, recwarn): - conn = await AsyncConnection.connect(dsn) +async def test_connection_warn_close(aconn_cls, dsn, recwarn): + conn = await aconn_cls.connect(dsn) await conn.close() del conn assert not recwarn, [str(w.message) for w in recwarn.list] - conn = await AsyncConnection.connect(dsn) + conn = await aconn_cls.connect(dsn) del conn assert "IDLE" in str(recwarn.pop(ResourceWarning).message) - conn = await AsyncConnection.connect(dsn) + conn = await aconn_cls.connect(dsn) await conn.execute("select 1") del conn assert "INTRANS" in str(recwarn.pop(ResourceWarning).message) - conn = await AsyncConnection.connect(dsn) + conn = await aconn_cls.connect(dsn) try: await conn.execute("select wat") except Exception: @@ -119,14 +117,14 @@ async def test_connection_warn_close(dsn, recwarn): del conn assert "INERROR" in str(recwarn.pop(ResourceWarning).message) - async with await AsyncConnection.connect(dsn) as conn: + async with await aconn_cls.connect(dsn) as conn: pass del conn assert not recwarn, [str(w.message) for w in recwarn.list] @pytest.mark.usefixtures("testctx") -async def test_context_commit(aconn, dsn): +async def test_context_commit(aconn_cls, aconn, dsn): async with aconn: async with aconn.cursor() as cur: await cur.execute("insert into testctx values (42)") @@ -134,14 +132,14 @@ async def test_context_commit(aconn, dsn): assert aconn.closed assert not aconn.broken - async with await psycopg.AsyncConnection.connect(dsn) as aconn: + async with await aconn_cls.connect(dsn) as aconn: async with aconn.cursor() as cur: await cur.execute("select * from testctx") assert await cur.fetchall() == [(42,)] @pytest.mark.usefixtures("testctx") -async def test_context_rollback(aconn, dsn): +async def test_context_rollback(aconn_cls, aconn, dsn): with pytest.raises(ZeroDivisionError): async with aconn: async with aconn.cursor() as cur: @@ -151,7 +149,7 @@ async def test_context_rollback(aconn, dsn): assert aconn.closed assert not aconn.broken - async with await psycopg.AsyncConnection.connect(dsn) as aconn: + async with await aconn_cls.connect(dsn) as aconn: async with aconn.cursor() as cur: await cur.execute("select * from testctx") assert await cur.fetchall() == [] @@ -164,9 +162,9 @@ async def test_context_close(aconn): @pytest.mark.crdb("skip", reason="pg_terminate_backend") -async def test_context_inerror_rollback_no_clobber(conn, dsn, caplog): +async def test_context_inerror_rollback_no_clobber(aconn_cls, conn, dsn, caplog): with pytest.raises(ZeroDivisionError): - async with await psycopg.AsyncConnection.connect(dsn) as conn2: + async with await aconn_cls.connect(dsn) as conn2: await conn2.execute("select 1") conn.execute( "select pg_terminate_backend(%s::int)", @@ -181,11 +179,11 @@ async def test_context_inerror_rollback_no_clobber(conn, dsn, caplog): @pytest.mark.crdb("skip", reason="copy") -async def test_context_active_rollback_no_clobber(dsn, caplog): +async def test_context_active_rollback_no_clobber(aconn_cls, dsn, caplog): caplog.set_level(logging.WARNING, logger="psycopg") with pytest.raises(ZeroDivisionError): - async with await psycopg.AsyncConnection.connect(dsn) as conn: + async with await aconn_cls.connect(dsn) as conn: conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout") assert not conn.pgconn.error_message status = conn.info.transaction_status @@ -199,8 +197,8 @@ async def test_context_active_rollback_no_clobber(dsn, caplog): @pytest.mark.slow -async def test_weakref(dsn): - conn = await psycopg.AsyncConnection.connect(dsn) +async def test_weakref(aconn_cls, dsn): + conn = await aconn_cls.connect(dsn) w = weakref.ref(conn) await conn.close() del conn @@ -317,8 +315,8 @@ async def test_autocommit(aconn): assert aconn.autocommit is True -async def test_autocommit_connect(dsn): - aconn = await psycopg.AsyncConnection.connect(dsn, autocommit=True) +async def test_autocommit_connect(aconn_cls, dsn): + aconn = await aconn_cls.connect(dsn, autocommit=True) assert aconn.autocommit await aconn.close() @@ -366,7 +364,7 @@ async def test_autocommit_unknown(aconn): (("dbname=foo",), {"user": None}, "dbname=foo"), ], ) -async def test_connect_args(monkeypatch, pgconn, args, kwargs, want): +async def test_connect_args(aconn_cls, monkeypatch, pgconn, args, kwargs, want): the_conninfo: str def fake_connect(conninfo): @@ -376,7 +374,7 @@ async def test_connect_args(monkeypatch, pgconn, args, kwargs, want): yield monkeypatch.setattr(psycopg.connection, "connect", fake_connect) - conn = await psycopg.AsyncConnection.connect(*args, **kwargs) + conn = await aconn_cls.connect(*args, **kwargs) assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want) await conn.close() @@ -389,14 +387,14 @@ async def test_connect_args(monkeypatch, pgconn, args, kwargs, want): ((), {"nosuchparam": 42}, psycopg.ProgrammingError), ], ) -async def test_connect_badargs(monkeypatch, pgconn, args, kwargs, exctype): +async def test_connect_badargs(aconn_cls, monkeypatch, pgconn, args, kwargs, exctype): def fake_connect(conninfo): return pgconn yield monkeypatch.setattr(psycopg.connection, "connect", fake_connect) with pytest.raises(exctype): - await psycopg.AsyncConnection.connect(*args, **kwargs) + await aconn_cls.connect(*args, **kwargs) @pytest.mark.crdb("skip", reason="pg_terminate_backend") @@ -511,13 +509,13 @@ async def test_execute_binary(aconn): assert cur.pgresult.fformat(0) == 1 -async def test_row_factory(dsn): - defaultconn = await AsyncConnection.connect(dsn) - assert defaultconn.row_factory is tuple_row # type: ignore[comparison-overlap] +async def test_row_factory(aconn_cls, dsn): + defaultconn = await aconn_cls.connect(dsn) + assert defaultconn.row_factory is tuple_row await defaultconn.close() - conn = await AsyncConnection.connect(dsn, row_factory=my_row_factory) - assert conn.row_factory is my_row_factory # type: ignore[comparison-overlap] + conn = await aconn_cls.connect(dsn, row_factory=my_row_factory) + assert conn.row_factory is my_row_factory cur = await conn.execute("select 'a' as ve") assert await cur.fetchone() == ["Ave"] @@ -531,10 +529,10 @@ async def test_row_factory(dsn): assert await cur2.fetchall() == [(1, 1, 2)] # TODO: maybe fix something to get rid of 'type: ignore' below. - conn.row_factory = tuple_row # type: ignore[assignment] + conn.row_factory = tuple_row cur3 = await conn.execute("select 'vale'") r = await cur3.fetchone() - assert r and r == ("vale",) # type: ignore[comparison-overlap] + assert r and r == ("vale",) await conn.close() @@ -565,13 +563,11 @@ async def test_cursor_factory(aconn): assert isinstance(cur, MyCursor) -async def test_cursor_factory_connect(dsn): +async def test_cursor_factory_connect(aconn_cls, dsn): class MyCursor(psycopg.AsyncCursor[psycopg.rows.Row]): pass - async with await psycopg.AsyncConnection.connect( - dsn, cursor_factory=MyCursor - ) as conn: + async with await aconn_cls.connect(dsn, cursor_factory=MyCursor) as conn: assert conn.cursor_factory is MyCursor cur = conn.cursor() assert type(cur) is MyCursor @@ -695,37 +691,37 @@ async def test_set_transaction_param_strange(aconn): @pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout) -async def test_get_connection_params(dsn, kwargs, exp): - params = await AsyncConnection._get_connection_params(dsn, **kwargs) +async def test_get_connection_params(aconn_cls, dsn, kwargs, exp): + params = await aconn_cls._get_connection_params(dsn, **kwargs) conninfo = make_conninfo(**params) assert conninfo_to_dict(conninfo) == exp[0] assert params["connect_timeout"] == exp[1] -async def test_connect_context_adapters(dsn): +async def test_connect_context_adapters(aconn_cls, dsn): ctx = psycopg.adapt.AdaptersMap(psycopg.adapters) ctx.register_dumper(str, make_bin_dumper("b")) ctx.register_dumper(str, make_dumper("t")) - conn = await psycopg.AsyncConnection.connect(dsn, context=ctx) + conn = await aconn_cls.connect(dsn, context=ctx) cur = await conn.execute("select %s", ["hello"]) - assert (await cur.fetchone())[0] == "hellot" # type: ignore[index] + assert (await cur.fetchone())[0] == "hellot" cur = await conn.execute("select %b", ["hello"]) - assert (await cur.fetchone())[0] == "hellob" # type: ignore[index] + assert (await cur.fetchone())[0] == "hellob" await conn.close() -async def test_connect_context_copy(dsn, aconn): +async def test_connect_context_copy(aconn_cls, dsn, aconn): aconn.adapters.register_dumper(str, make_bin_dumper("b")) aconn.adapters.register_dumper(str, make_dumper("t")) - aconn2 = await psycopg.AsyncConnection.connect(dsn, context=aconn) + aconn2 = await aconn_cls.connect(dsn, context=aconn) cur = await aconn2.execute("select %s", ["hello"]) - assert (await cur.fetchone())[0] == "hellot" # type: ignore[index] + assert (await cur.fetchone())[0] == "hellot" cur = await aconn2.execute("select %b", ["hello"]) - assert (await cur.fetchone())[0] == "hellob" # type: ignore[index] + assert (await cur.fetchone())[0] == "hellob" await aconn2.close() diff --git a/tests/test_conninfo.py b/tests/test_conninfo.py index 70571d286..e9d959fb8 100644 --- a/tests/test_conninfo.py +++ b/tests/test_conninfo.py @@ -10,7 +10,6 @@ from psycopg.conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo from psycopg.conninfo import resolve_hostaddr_async from psycopg._encodings import pg2pyenc -from .fix_db import connection_class from .fix_crdb import crdb_encoding snowman = "\u2603" @@ -144,28 +143,28 @@ class TestConnectionInfo: if k != "password": assert f"{k}=" in dsn - def test_get_params_env(self, dsn, monkeypatch): + def test_get_params_env(self, conn_cls, dsn, monkeypatch): dsn = conninfo_to_dict(dsn) dsn.pop("application_name", None) monkeypatch.delenv("PGAPPNAME", raising=False) - with connection_class().connect(**dsn) as conn: + with conn_cls.connect(**dsn) as conn: assert "application_name" not in conn.info.get_parameters() monkeypatch.setenv("PGAPPNAME", "hello test") - with connection_class().connect(**dsn) as conn: + with conn_cls.connect(**dsn) as conn: assert conn.info.get_parameters()["application_name"] == "hello test" - def test_dsn_env(self, dsn, monkeypatch): + def test_dsn_env(self, conn_cls, dsn, monkeypatch): dsn = conninfo_to_dict(dsn) dsn.pop("application_name", None) monkeypatch.delenv("PGAPPNAME", raising=False) - with connection_class().connect(**dsn) as conn: + with conn_cls.connect(**dsn) as conn: assert "application_name=" not in conn.info.dsn monkeypatch.setenv("PGAPPNAME", "hello test") - with connection_class().connect(**dsn) as conn: + with conn_cls.connect(**dsn) as conn: assert "application_name='hello test'" in conn.info.dsn def test_status(self, conn): @@ -295,9 +294,9 @@ class TestConnectionInfo: crdb_encoding("euc-jp", "EUC_JP", "euc_jp"), ], ) - def test_encoding_env_var(self, dsn, monkeypatch, enc, out, codec): + def test_encoding_env_var(self, conn_cls, dsn, monkeypatch, enc, out, codec): monkeypatch.setenv("PGCLIENTENCODING", enc) - with connection_class().connect(dsn) as conn: + with conn_cls.connect(dsn) as conn: clienc = conn.info.parameter_status("client_encoding") assert clienc if conn.info.vendor == "PostgreSQL": diff --git a/tests/test_copy.py b/tests/test_copy.py index 73e88d7d2..5eb8174d8 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -620,13 +620,13 @@ def test_worker_error_propagated(conn, monkeypatch): [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], ) @pytest.mark.parametrize("method", ["read", "iter", "row", "rows"]) -def test_copy_to_leaks(dsn, faker, fmt, set_types, method): +def test_copy_to_leaks(conn_cls, dsn, faker, fmt, set_types, method): faker.format = PyFormat.from_pq(fmt) faker.choose_schema(ncols=20) faker.make_records(20) def work(): - with psycopg.connect(dsn) as conn: + with conn_cls.connect(dsn) as conn: with conn.cursor(binary=fmt) as cur: cur.execute(faker.drop_stmt) cur.execute(faker.create_stmt) @@ -654,7 +654,7 @@ def test_copy_to_leaks(dsn, faker, fmt, set_types, method): list(copy) elif method == "row": while True: - tmp = copy.read_row() # type: ignore[assignment] + tmp = copy.read_row() if tmp is None: break elif method == "rows": @@ -675,13 +675,13 @@ def test_copy_to_leaks(dsn, faker, fmt, set_types, method): "fmt, set_types", [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], ) -def test_copy_from_leaks(dsn, faker, fmt, set_types): +def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types): faker.format = PyFormat.from_pq(fmt) faker.choose_schema(ncols=20) faker.make_records(20) def work(): - with psycopg.connect(dsn) as conn: + with conn_cls.connect(dsn) as conn: with conn.cursor(binary=fmt) as cur: cur.execute(faker.drop_stmt) cur.execute(faker.create_stmt) @@ -715,11 +715,11 @@ def test_copy_from_leaks(dsn, faker, fmt, set_types): @pytest.mark.slow @pytest.mark.parametrize("mode", ["row", "block", "binary"]) -def test_copy_table_across(dsn, faker, mode): +def test_copy_table_across(conn_cls, dsn, faker, mode): faker.choose_schema(ncols=20) faker.make_records(20) - with psycopg.connect(dsn) as conn1, psycopg.connect(dsn) as conn2: + with conn_cls.connect(dsn) as conn1, conn_cls.connect(dsn) as conn2: faker.table_name = sql.Identifier("copy_src") conn1.execute(faker.drop_stmt) conn1.execute(faker.create_stmt) diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index 12ae52d83..46f905df3 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -625,13 +625,13 @@ async def test_worker_error_propagated(aconn, monkeypatch): [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], ) @pytest.mark.parametrize("method", ["read", "iter", "row", "rows"]) -async def test_copy_to_leaks(dsn, faker, fmt, set_types, method): +async def test_copy_to_leaks(aconn_cls, dsn, faker, fmt, set_types, method): faker.format = PyFormat.from_pq(fmt) faker.choose_schema(ncols=20) faker.make_records(20) async def work(): - async with await psycopg.AsyncConnection.connect(dsn) as conn: + async with await aconn_cls.connect(dsn) as conn: async with conn.cursor(binary=fmt) as cur: await cur.execute(faker.drop_stmt) await cur.execute(faker.create_stmt) @@ -659,7 +659,7 @@ async def test_copy_to_leaks(dsn, faker, fmt, set_types, method): await alist(copy) elif method == "row": while True: - tmp = await copy.read_row() # type: ignore[assignment] + tmp = await copy.read_row() if tmp is None: break elif method == "rows": @@ -680,13 +680,13 @@ async def test_copy_to_leaks(dsn, faker, fmt, set_types, method): "fmt, set_types", [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], ) -async def test_copy_from_leaks(dsn, faker, fmt, set_types): +async def test_copy_from_leaks(aconn_cls, dsn, faker, fmt, set_types): faker.format = PyFormat.from_pq(fmt) faker.choose_schema(ncols=20) faker.make_records(20) async def work(): - async with await psycopg.AsyncConnection.connect(dsn) as conn: + async with await aconn_cls.connect(dsn) as conn: async with conn.cursor(binary=fmt) as cur: await cur.execute(faker.drop_stmt) await cur.execute(faker.create_stmt) @@ -720,11 +720,11 @@ async def test_copy_from_leaks(dsn, faker, fmt, set_types): @pytest.mark.slow @pytest.mark.parametrize("mode", ["row", "block", "binary"]) -async def test_copy_table_across(dsn, faker, mode): +async def test_copy_table_across(aconn_cls, dsn, faker, mode): faker.choose_schema(ncols=20) faker.make_records(20) - connect = psycopg.AsyncConnection.connect + connect = aconn_cls.connect async with await connect(dsn) as conn1, await connect(dsn) as conn2: faker.table_name = sql.Identifier("copy_src") await conn1.execute(faker.drop_stmt) diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 18b561c7a..374dde452 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -833,14 +833,14 @@ def test_str(conn): @pytest.mark.parametrize("fmt_out", pq.Format) @pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"]) @pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) -def test_leak(dsn, faker, fmt, fmt_out, fetch, row_factory): +def test_leak(conn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory): faker.format = fmt faker.choose_schema(ncols=5) faker.make_records(10) row_factory = getattr(rows, row_factory) def work(): - with psycopg.connect(dsn) as conn, conn.transaction(force_rollback=True): + with conn_cls.connect(dsn) as conn, conn.transaction(force_rollback=True): with conn.cursor(binary=fmt_out, row_factory=row_factory) as cur: cur.execute(faker.drop_stmt) cur.execute(faker.create_stmt) diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py index c2519edbe..8d11c9e98 100644 --- a/tests/test_cursor_async.py +++ b/tests/test_cursor_async.py @@ -703,14 +703,14 @@ async def test_str(aconn): @pytest.mark.parametrize("fmt_out", pq.Format) @pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"]) @pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) -async def test_leak(dsn, faker, fmt, fmt_out, fetch, row_factory): +async def test_leak(aconn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory): faker.format = fmt faker.choose_schema(ncols=5) faker.make_records(10) row_factory = getattr(rows, row_factory) async def work(): - async with await psycopg.AsyncConnection.connect(dsn) as conn, conn.transaction( + async with await aconn_cls.connect(dsn) as conn, conn.transaction( force_rollback=True ): async with conn.cursor(binary=fmt_out, row_factory=row_factory) as cur: diff --git a/tests/test_errors.py b/tests/test_errors.py index cb9ea145f..1d8b78c75 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -270,18 +270,18 @@ def test_unknown_sqlstate(conn): assert pexc.sqlstate == code -def test_pgconn_error(): +def test_pgconn_error(conn_cls): with pytest.raises(psycopg.OperationalError) as excinfo: - psycopg.connect("dbname=nosuchdb") + conn_cls.connect("dbname=nosuchdb") exc = excinfo.value assert exc.pgconn assert exc.pgconn.db == b"nosuchdb" -def test_pgconn_error_pickle(): +def test_pgconn_error_pickle(conn_cls): with pytest.raises(psycopg.OperationalError) as excinfo: - psycopg.connect("dbname=nosuchdb") + conn_cls.connect("dbname=nosuchdb") exc = pickle.loads(pickle.dumps(excinfo.value)) assert exc.pgconn is None diff --git a/tests/test_prepared.py b/tests/test_prepared.py index dea238ca4..d7b4c7c1e 100644 --- a/tests/test_prepared.py +++ b/tests/test_prepared.py @@ -7,13 +7,12 @@ from decimal import Decimal import pytest -import psycopg from psycopg.rows import namedtuple_row @pytest.mark.parametrize("value", [None, 0, 3]) -def test_prepare_threshold_init(dsn, value): - with psycopg.connect(dsn, prepare_threshold=value) as conn: +def test_prepare_threshold_init(conn_cls, dsn, value): + with conn_cls.connect(dsn, prepare_threshold=value) as conn: assert conn.prepare_threshold == value diff --git a/tests/test_prepared_async.py b/tests/test_prepared_async.py index a40a169e4..cab7cd94d 100644 --- a/tests/test_prepared_async.py +++ b/tests/test_prepared_async.py @@ -7,17 +7,14 @@ from decimal import Decimal import pytest -import psycopg from psycopg.rows import namedtuple_row pytestmark = pytest.mark.asyncio @pytest.mark.parametrize("value", [None, 0, 3]) -async def test_prepare_threshold_init(dsn, value): - async with await psycopg.AsyncConnection.connect( - dsn, prepare_threshold=value - ) as conn: +async def test_prepare_threshold_init(aconn_cls, dsn, value): + async with await aconn_cls.connect(dsn, prepare_threshold=value) as conn: assert conn.prepare_threshold == value diff --git a/tests/test_tpc.py b/tests/test_tpc.py index e93049c54..fad02a8b1 100644 --- a/tests/test_tpc.py +++ b/tests/test_tpc.py @@ -57,7 +57,7 @@ class TestTPC: assert tpc.count_xacts() == 0 assert tpc.count_test_records() == 1 - def test_tpc_commit_recovered(self, conn, dsn, tpc): + def test_tpc_commit_recovered(self, conn_cls, conn, dsn, tpc): xid = conn.xid(1, "gtrid", "bqual") assert conn.info.transaction_status == TransactionStatus.IDLE @@ -74,7 +74,7 @@ class TestTPC: assert tpc.count_xacts() == 1 assert tpc.count_test_records() == 0 - with psycopg.connect(dsn) as conn: + with conn_cls.connect(dsn) as conn: xid = conn.xid(1, "gtrid", "bqual") conn.tpc_commit(xid) assert conn.info.transaction_status == TransactionStatus.IDLE @@ -121,7 +121,7 @@ class TestTPC: assert tpc.count_xacts() == 0 assert tpc.count_test_records() == 0 - def test_tpc_rollback_recovered(self, conn, dsn, tpc): + def test_tpc_rollback_recovered(self, conn_cls, conn, dsn, tpc): xid = conn.xid(1, "gtrid", "bqual") assert conn.info.transaction_status == TransactionStatus.IDLE @@ -138,7 +138,7 @@ class TestTPC: assert tpc.count_xacts() == 1 assert tpc.count_test_records() == 0 - with psycopg.connect(dsn) as conn: + with conn_cls.connect(dsn) as conn: xid = conn.xid(1, "gtrid", "bqual") conn.tpc_rollback(xid) assert conn.info.transaction_status == TransactionStatus.IDLE @@ -207,13 +207,13 @@ class TestTPC: (0x7FFFFFFF, "x" * 64, "y" * 64), ], ) - def test_xid_roundtrip(self, conn, dsn, tpc, fid, gtrid, bqual): + def test_xid_roundtrip(self, conn_cls, conn, dsn, tpc, fid, gtrid, bqual): xid = conn.xid(fid, gtrid, bqual) conn.tpc_begin(xid) conn.tpc_prepare() conn.close() - with psycopg.connect(dsn) as conn: + with conn_cls.connect(dsn) as conn: xids = [x for x in conn.tpc_recover() if x.database == conn.info.dbname] assert len(xids) == 1 @@ -232,12 +232,12 @@ class TestTPC: "x" * 199, # PostgreSQL's limit in transaction id length ], ) - def test_unparsed_roundtrip(self, conn, dsn, tpc, tid): + def test_unparsed_roundtrip(self, conn_cls, conn, dsn, tpc, tid): conn.tpc_begin(tid) conn.tpc_prepare() conn.close() - with psycopg.connect(dsn) as conn: + with conn_cls.connect(dsn) as conn: xids = [x for x in conn.tpc_recover() if x.database == conn.info.dbname] assert len(xids) == 1 @@ -248,19 +248,19 @@ class TestTPC: assert xid.gtrid == tid assert xid.bqual is None - def test_xid_unicode(self, conn, dsn, tpc): + def test_xid_unicode(self, conn_cls, conn, dsn, tpc): x1 = conn.xid(10, "uni", "code") conn.tpc_begin(x1) conn.tpc_prepare() conn.close() - with psycopg.connect(dsn) as conn: + with conn_cls.connect(dsn) as conn: xid = [x for x in conn.tpc_recover() if x.database == conn.info.dbname][0] assert 10 == xid.format_id assert "uni" == xid.gtrid assert "code" == xid.bqual - def test_xid_unicode_unparsed(self, conn, dsn, tpc): + def test_xid_unicode_unparsed(self, conn_cls, conn, dsn, tpc): # We don't expect people shooting snowmen as transaction ids, # so if something explodes in an encode error I don't mind. # Let's just check unicode is accepted as type. @@ -271,7 +271,7 @@ class TestTPC: conn.tpc_prepare() conn.close() - with psycopg.connect(dsn) as conn: + with conn_cls.connect(dsn) as conn: xid = [x for x in conn.tpc_recover() if x.database == conn.info.dbname][0] assert xid.format_id is None @@ -284,13 +284,13 @@ class TestTPC: with pytest.raises(psycopg.ProgrammingError): conn.cancel() - def test_tpc_recover_non_dbapi_connection(self, conn, dsn, tpc): + def test_tpc_recover_non_dbapi_connection(self, conn_cls, conn, dsn, tpc): conn.row_factory = psycopg.rows.dict_row conn.tpc_begin("dict-connection") conn.tpc_prepare() conn.close() - with psycopg.connect(dsn) as conn: + with conn_cls.connect(dsn) as conn: xids = conn.tpc_recover() xid = [x for x in xids if x.database == conn.info.dbname][0] diff --git a/tests/test_tpc_async.py b/tests/test_tpc_async.py index c08237627..fd31ee033 100644 --- a/tests/test_tpc_async.py +++ b/tests/test_tpc_async.py @@ -61,7 +61,7 @@ class TestTPC: assert tpc.count_xacts() == 0 assert tpc.count_test_records() == 1 - async def test_tpc_commit_recovered(self, aconn, dsn, tpc): + async def test_tpc_commit_recovered(self, aconn_cls, aconn, dsn, tpc): xid = aconn.xid(1, "gtrid", "bqual") assert aconn.info.transaction_status == TransactionStatus.IDLE @@ -78,7 +78,7 @@ class TestTPC: assert tpc.count_xacts() == 1 assert tpc.count_test_records() == 0 - async with await psycopg.AsyncConnection.connect(dsn) as aconn: + async with await aconn_cls.connect(dsn) as aconn: xid = aconn.xid(1, "gtrid", "bqual") await aconn.tpc_commit(xid) assert aconn.info.transaction_status == TransactionStatus.IDLE @@ -125,7 +125,7 @@ class TestTPC: assert tpc.count_xacts() == 0 assert tpc.count_test_records() == 0 - async def test_tpc_rollback_recovered(self, aconn, dsn, tpc): + async def test_tpc_rollback_recovered(self, aconn_cls, aconn, dsn, tpc): xid = aconn.xid(1, "gtrid", "bqual") assert aconn.info.transaction_status == TransactionStatus.IDLE @@ -142,7 +142,7 @@ class TestTPC: assert tpc.count_xacts() == 1 assert tpc.count_test_records() == 0 - async with await psycopg.AsyncConnection.connect(dsn) as aconn: + async with await aconn_cls.connect(dsn) as aconn: xid = aconn.xid(1, "gtrid", "bqual") await aconn.tpc_rollback(xid) assert aconn.info.transaction_status == TransactionStatus.IDLE @@ -211,13 +211,13 @@ class TestTPC: (0x7FFFFFFF, "x" * 64, "y" * 64), ], ) - async def test_xid_roundtrip(self, aconn, dsn, tpc, fid, gtrid, bqual): + async def test_xid_roundtrip(self, aconn_cls, aconn, dsn, tpc, fid, gtrid, bqual): xid = aconn.xid(fid, gtrid, bqual) await aconn.tpc_begin(xid) await aconn.tpc_prepare() await aconn.close() - async with await psycopg.AsyncConnection.connect(dsn) as aconn: + async with await aconn_cls.connect(dsn) as aconn: xids = [ x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname ] @@ -237,12 +237,12 @@ class TestTPC: "x" * 199, # PostgreSQL's limit in transaction id length ], ) - async def test_unparsed_roundtrip(self, aconn, dsn, tpc, tid): + async def test_unparsed_roundtrip(self, aconn_cls, aconn, dsn, tpc, tid): await aconn.tpc_begin(tid) await aconn.tpc_prepare() await aconn.close() - async with await psycopg.AsyncConnection.connect(dsn) as aconn: + async with await aconn_cls.connect(dsn) as aconn: xids = [ x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname ] @@ -254,13 +254,13 @@ class TestTPC: assert xid.gtrid == tid assert xid.bqual is None - async def test_xid_unicode(self, aconn, dsn, tpc): + async def test_xid_unicode(self, aconn_cls, aconn, dsn, tpc): x1 = aconn.xid(10, "uni", "code") await aconn.tpc_begin(x1) await aconn.tpc_prepare() await aconn.close() - async with await psycopg.AsyncConnection.connect(dsn) as aconn: + async with await aconn_cls.connect(dsn) as aconn: xid = [ x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname ][0] @@ -269,7 +269,7 @@ class TestTPC: assert "uni" == xid.gtrid assert "code" == xid.bqual - async def test_xid_unicode_unparsed(self, aconn, dsn, tpc): + async def test_xid_unicode_unparsed(self, aconn_cls, aconn, dsn, tpc): # We don't expect people shooting snowmen as transaction ids, # so if something explodes in an encode error I don't mind. # Let's just check unicode is accepted as type. @@ -280,7 +280,7 @@ class TestTPC: await aconn.tpc_prepare() await aconn.close() - async with await psycopg.AsyncConnection.connect(dsn) as aconn: + async with await aconn_cls.connect(dsn) as aconn: xid = [ x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname ][0] @@ -295,13 +295,13 @@ class TestTPC: with pytest.raises(psycopg.ProgrammingError): aconn.cancel() - async def test_tpc_recover_non_dbapi_connection(self, aconn, dsn, tpc): + async def test_tpc_recover_non_dbapi_connection(self, aconn_cls, aconn, dsn, tpc): aconn.row_factory = psycopg.rows.dict_row await aconn.tpc_begin("dict-connection") await aconn.tpc_prepare() await aconn.close() - async with await psycopg.AsyncConnection.connect(dsn) as aconn: + async with await aconn_cls.connect(dsn) as aconn: xids = await aconn.tpc_recover() xid = [x for x in xids if x.database == aconn.info.dbname][0] diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 1c03540aa..9cb962c3b 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -4,7 +4,9 @@ from threading import Thread, Event import pytest -from psycopg import Connection, ProgrammingError, Rollback +import psycopg +from psycopg import Rollback +from psycopg import errors as e # TODOCRDB: is this the expected behaviour? crdb_skip_external_observer = pytest.mark.crdb( @@ -29,7 +31,7 @@ def create_test_table(svcconn): def insert_row(conn, value): sql = "INSERT INTO test_table VALUES (%s)" - if isinstance(conn, Connection): + if isinstance(conn, psycopg.Connection): conn.cursor().execute(sql, (value,)) else: @@ -43,7 +45,7 @@ def insert_row(conn, value): def inserted(conn): """Return the values inserted in the test table.""" sql = "SELECT * FROM test_table" - if isinstance(conn, Connection): + if isinstance(conn, psycopg.Connection): rows = conn.cursor().execute(sql).fetchall() return set(v for (v,) in rows) else: @@ -147,7 +149,7 @@ def test_rollback_on_exception_exit(conn): @pytest.mark.crdb("skip", reason="pg_terminate_backend") -def test_context_inerror_rollback_no_clobber(conn, pipeline, dsn, caplog): +def test_context_inerror_rollback_no_clobber(conn_cls, conn, pipeline, dsn, caplog): if pipeline: # Only 'conn' is possibly in pipeline mode, but the transaction and # checks are on 'conn2'. @@ -155,7 +157,7 @@ def test_context_inerror_rollback_no_clobber(conn, pipeline, dsn, caplog): caplog.set_level(logging.WARNING, logger="psycopg") with pytest.raises(ZeroDivisionError): - with Connection.connect(dsn) as conn2: + with conn_cls.connect(dsn) as conn2: with conn2.transaction(): conn2.execute("select 1") conn.execute( @@ -171,10 +173,10 @@ def test_context_inerror_rollback_no_clobber(conn, pipeline, dsn, caplog): @pytest.mark.crdb("skip", reason="copy") -def test_context_active_rollback_no_clobber(dsn, caplog): +def test_context_active_rollback_no_clobber(conn_cls, dsn, caplog): caplog.set_level(logging.WARNING, logger="psycopg") - conn = Connection.connect(dsn) + conn = conn_cls.connect(dsn) try: with pytest.raises(ZeroDivisionError): with conn.transaction(): @@ -217,11 +219,11 @@ def test_prohibits_use_of_commit_rollback_autocommit(conn): conn.rollback() with conn.transaction(): - with pytest.raises(ProgrammingError): + with pytest.raises(e.ProgrammingError): conn.autocommit = False - with pytest.raises(ProgrammingError): + with pytest.raises(e.ProgrammingError): conn.commit() - with pytest.raises(ProgrammingError): + with pytest.raises(e.ProgrammingError): conn.rollback() conn.autocommit = False @@ -710,10 +712,10 @@ def test_out_of_order_exit(conn, exit_error): t2 = conn.transaction() t2.__enter__() - with pytest.raises(ProgrammingError): + with pytest.raises(e.ProgrammingError): t1.__exit__(*get_exc_info(exit_error)) - with pytest.raises(ProgrammingError): + with pytest.raises(e.ProgrammingError): t2.__exit__(*get_exc_info(exit_error)) @@ -727,10 +729,10 @@ def test_out_of_order_implicit_begin(conn, exit_error): t2 = conn.transaction() t2.__enter__() - with pytest.raises(ProgrammingError): + with pytest.raises(e.ProgrammingError): t1.__exit__(*get_exc_info(exit_error)) - with pytest.raises(ProgrammingError): + with pytest.raises(e.ProgrammingError): t2.__exit__(*get_exc_info(exit_error)) @@ -743,10 +745,10 @@ def test_out_of_order_exit_same_name(conn, exit_error): t2 = conn.transaction("save") t2.__enter__() - with pytest.raises(ProgrammingError): + with pytest.raises(e.ProgrammingError): t1.__exit__(*get_exc_info(exit_error)) - with pytest.raises(ProgrammingError): + with pytest.raises(e.ProgrammingError): t2.__exit__(*get_exc_info(exit_error)) @@ -754,10 +756,10 @@ def test_out_of_order_exit_same_name(conn, exit_error): def test_concurrency(conn, what): conn.autocommit = True - e = [Event() for i in range(3)] + evs = [Event() for i in range(3)] def worker(unlock, wait_on): - with pytest.raises(ProgrammingError) as ex: + with pytest.raises(e.ProgrammingError) as ex: with conn.transaction(): unlock.set() wait_on.wait() @@ -780,15 +782,15 @@ def test_concurrency(conn, what): assert "transaction commit" in str(ex.value) # Start a first transaction in a thread - t1 = Thread(target=worker, kwargs={"unlock": e[0], "wait_on": e[1]}) + t1 = Thread(target=worker, kwargs={"unlock": evs[0], "wait_on": evs[1]}) t1.start() - e[0].wait() + evs[0].wait() # Start a nested transaction in a thread - t2 = Thread(target=worker, kwargs={"unlock": e[1], "wait_on": e[2]}) + t2 = Thread(target=worker, kwargs={"unlock": evs[1], "wait_on": evs[2]}) t2.start() # Terminate the first transaction before the second does t1.join() - e[2].set() + evs[2].set() t2.join() diff --git a/tests/test_transaction_async.py b/tests/test_transaction_async.py index cdaba9df3..6739d8b73 100644 --- a/tests/test_transaction_async.py +++ b/tests/test_transaction_async.py @@ -3,7 +3,8 @@ import logging import pytest -from psycopg import AsyncConnection, ProgrammingError, Rollback +from psycopg import Rollback +from psycopg import errors as e from psycopg._compat import create_task from .test_transaction import in_transaction, insert_row, inserted, get_exc_info @@ -85,7 +86,9 @@ async def test_rollback_on_exception_exit(aconn): @pytest.mark.crdb("skip", reason="pg_terminate_backend") -async def test_context_inerror_rollback_no_clobber(aconn, apipeline, dsn, caplog): +async def test_context_inerror_rollback_no_clobber( + aconn_cls, aconn, apipeline, dsn, caplog +): if apipeline: # Only 'aconn' is possibly in pipeline mode, but the transaction and # checks are on 'conn2'. @@ -93,7 +96,7 @@ async def test_context_inerror_rollback_no_clobber(aconn, apipeline, dsn, caplog caplog.set_level(logging.WARNING, logger="psycopg") with pytest.raises(ZeroDivisionError): - async with await AsyncConnection.connect(dsn) as conn2: + async with await aconn_cls.connect(dsn) as conn2: async with conn2.transaction(): await conn2.execute("select 1") await aconn.execute( @@ -109,10 +112,10 @@ async def test_context_inerror_rollback_no_clobber(aconn, apipeline, dsn, caplog @pytest.mark.crdb("skip", reason="copy") -async def test_context_active_rollback_no_clobber(dsn, caplog): +async def test_context_active_rollback_no_clobber(aconn_cls, dsn, caplog): caplog.set_level(logging.WARNING, logger="psycopg") - conn = await AsyncConnection.connect(dsn) + conn = await aconn_cls.connect(dsn) try: with pytest.raises(ZeroDivisionError): async with conn.transaction(): @@ -155,11 +158,11 @@ async def test_prohibits_use_of_commit_rollback_autocommit(aconn): await aconn.rollback() async with aconn.transaction(): - with pytest.raises(ProgrammingError): + with pytest.raises(e.ProgrammingError): await aconn.set_autocommit(False) - with pytest.raises(ProgrammingError): + with pytest.raises(e.ProgrammingError): await aconn.commit() - with pytest.raises(ProgrammingError): + with pytest.raises(e.ProgrammingError): await aconn.rollback() await aconn.set_autocommit(False) @@ -658,10 +661,10 @@ async def test_out_of_order_exit(aconn, exit_error): t2 = aconn.transaction() await t2.__aenter__() - with pytest.raises(ProgrammingError): + with pytest.raises(e.ProgrammingError): await t1.__aexit__(*get_exc_info(exit_error)) - with pytest.raises(ProgrammingError): + with pytest.raises(e.ProgrammingError): await t2.__aexit__(*get_exc_info(exit_error)) @@ -675,10 +678,10 @@ async def test_out_of_order_implicit_begin(aconn, exit_error): t2 = aconn.transaction() await t2.__aenter__() - with pytest.raises(ProgrammingError): + with pytest.raises(e.ProgrammingError): await t1.__aexit__(*get_exc_info(exit_error)) - with pytest.raises(ProgrammingError): + with pytest.raises(e.ProgrammingError): await t2.__aexit__(*get_exc_info(exit_error)) @@ -691,10 +694,10 @@ async def test_out_of_order_exit_same_name(aconn, exit_error): t2 = aconn.transaction("save") await t2.__aenter__() - with pytest.raises(ProgrammingError): + with pytest.raises(e.ProgrammingError): await t1.__aexit__(*get_exc_info(exit_error)) - with pytest.raises(ProgrammingError): + with pytest.raises(e.ProgrammingError): await t2.__aexit__(*get_exc_info(exit_error)) @@ -702,10 +705,10 @@ async def test_out_of_order_exit_same_name(aconn, exit_error): async def test_concurrency(aconn, what): await aconn.set_autocommit(True) - e = [asyncio.Event() for i in range(3)] + evs = [asyncio.Event() for i in range(3)] async def worker(unlock, wait_on): - with pytest.raises(ProgrammingError) as ex: + with pytest.raises(e.ProgrammingError) as ex: async with aconn.transaction(): unlock.set() await wait_on.wait() @@ -728,13 +731,13 @@ async def test_concurrency(aconn, what): assert "transaction commit" in str(ex.value) # Start a first transaction in a task - t1 = create_task(worker(unlock=e[0], wait_on=e[1])) - await e[0].wait() + t1 = create_task(worker(unlock=evs[0], wait_on=evs[1])) + await evs[0].wait() # Start a nested transaction in a task - t2 = create_task(worker(unlock=e[1], wait_on=e[2])) + t2 = create_task(worker(unlock=evs[1], wait_on=evs[2])) # Terminate the first transaction before the second does await asyncio.gather(t1) - e[2].set() + evs[2].set() await asyncio.gather(t2) diff --git a/tests/test_windows.py b/tests/test_windows.py index bb491cd16..09e61ba93 100644 --- a/tests/test_windows.py +++ b/tests/test_windows.py @@ -2,12 +2,11 @@ import pytest import asyncio import sys -import psycopg from psycopg.errors import InterfaceError @pytest.mark.skipif(sys.platform != "win32", reason="windows only test") -def test_windows_error(dsn): +def test_windows_error(aconn_cls, dsn): loop = asyncio.ProactorEventLoop() # type: ignore[attr-defined] async def go(): @@ -15,7 +14,7 @@ def test_windows_error(dsn): InterfaceError, match="Psycopg cannot use the 'ProactorEventLoop'", ): - await psycopg.AsyncConnection.connect(dsn) + await aconn_cls.connect(dsn) try: loop.run_until_complete(go()) diff --git a/tests/types/test_hstore.py b/tests/types/test_hstore.py index 1eac009ff..df67d17e5 100644 --- a/tests/types/test_hstore.py +++ b/tests/types/test_hstore.py @@ -66,13 +66,13 @@ def test_register_curs(hstore, conn): assert cur.fetchone() == (None, {}, {"a": "b"}) -def test_register_globally(hstore, dsn, svcconn, global_adapters): +def test_register_globally(conn_cls, hstore, dsn, svcconn, global_adapters): info = TypeInfo.fetch(svcconn, "hstore") register_hstore(info) assert psycopg.adapters.types[info.oid].name == "hstore" assert svcconn.adapters.types.get(info.oid) is None - conn = psycopg.connect(dsn) + conn = conn_cls.connect(dsn) assert conn.adapters.types[info.oid].name == "hstore" cur = conn.execute("select null::hstore, ''::hstore, 'a => b'::hstore")