From: Daniele Varrazzo Date: Tue, 24 May 2022 13:05:13 +0000 (+0200) Subject: test: fix skipping versions in parametrized tests X-Git-Tag: 3.1~49^2~50 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3579ca5024df206bcfe0578b33d48e3a15f35d1b;p=thirdparty%2Fpsycopg.git test: fix skipping versions in parametrized tests In the previous implementation, all the tests were skipped, because annotations were added to the function and ended up affecting all generated tests, not the marked ones only. --- diff --git a/tests/fix_crdb.py b/tests/fix_crdb.py index 15243fa92..dfa98ad9c 100644 --- a/tests/fix_crdb.py +++ b/tests/fix_crdb.py @@ -13,21 +13,7 @@ def pytest_configure(config): ) -def pytest_runtest_setup(item): - for m in item.iter_markers(name="crdb"): - if len(m.args) > 1: - raise TypeError("max one argument expected") - kwargs_unk = set(m.kwargs) - {"reason"} - if kwargs_unk: - raise TypeError(f"unknown keyword arguments: {kwargs_unk}") - - # Copy the want marker on the function so we can check the version - # after the connection has been created. - item.function.want_crdb = m.args[0] if m.args else "only" - item.function.crdb_reason = m.kwargs.get("reason") - - -def check_crdb_version(got, func): +def check_crdb_version(got, mark): """ Verify if the CockroachDB version is a version accepted. @@ -39,17 +25,19 @@ def check_crdb_version(got, func): and skips the test if the server version doesn't match what expected. """ - want = func.want_crdb + assert len(mark.args) <= 1 + assert not (set(mark.kwargs) - {"reason"}) + want = mark.args[0] if mark.args else "only" msg = None if got is None: if want == "only": - return "skipping test: CockroachDB only" + msg = "skipping test: CockroachDB only" else: if want == "only": pass elif want == "skip": - msg = crdb_skip_message(func.crdb_reason) + msg = crdb_skip_message(mark.kwargs.get("reason")) else: msg = check_version(got, want, "CockroachDB") diff --git a/tests/fix_db.py b/tests/fix_db.py index d5f9e4a1a..a069e9d95 100644 --- a/tests/fix_db.py +++ b/tests/fix_db.py @@ -67,14 +67,6 @@ def pytest_configure(config): ) -def pytest_runtest_setup(item): - # Copy the want marker on the function so we can check the version - # after the connection has been created. - want_ver = [m.args[0] for m in item.iter_markers() if m.name == "pg"] - if want_ver: - item.function.want_pg_version = want_ver[0] - - @pytest.fixture(scope="session") def session_dsn(request): """ @@ -98,7 +90,7 @@ def session_dsn(request): @pytest.fixture def dsn(session_dsn, request): """Return the dsn used to connect to the `--test-dsn` database.""" - check_connection_version(request.function) + check_connection_version(request.node) return session_dsn @@ -143,7 +135,7 @@ def maybe_trace(pgconn, tracefile, function): @pytest.fixture def pgconn(dsn, request, tracefile): """Return a PGconn connection open to `--test-dsn`.""" - check_connection_version(request.function) + check_connection_version(request.node) conn = pq.PGconn.connect(dsn.encode()) if conn.status != pq.ConnStatus.OK: @@ -158,7 +150,7 @@ def pgconn(dsn, request, tracefile): @pytest.fixture def conn(dsn, request, tracefile): """Return a `Connection` connected to the ``--test-dsn`` database.""" - check_connection_version(request.function) + check_connection_version(request.node) cls = psycopg.Connection if crdb_version: @@ -186,7 +178,7 @@ def pipeline(request, conn): @pytest.fixture async def aconn(dsn, request, tracefile): """Return an `AsyncConnection` connected to the ``--test-dsn`` database.""" - check_connection_version(request.function) + check_connection_version(request.node) cls = psycopg.AsyncConnection if crdb_version: @@ -263,26 +255,26 @@ class ListPopAll(list): # type: ignore[type-arg] return out -def check_connection_version(function): +def check_connection_version(node): try: pg_version except NameError: # First connection creation failed. Let the tests fail. - return None + pytest.fail("server version not available") - if hasattr(function, "want_pg_version"): - msg = check_server_version(pg_version, function) - if msg: - pytest.skip(msg) + for mark in node.iter_markers(): + if mark.name == "pg": + assert len(mark.args) == 1 + msg = check_server_version(pg_version, mark.args[0]) + if msg: + pytest.skip(msg) - if hasattr(function, "want_crdb"): - from .fix_crdb import check_crdb_version - - msg = check_crdb_version(crdb_version, function) - if msg: - pytest.skip(msg) + elif mark.name == "crdb": + from .fix_crdb import check_crdb_version - return None + msg = check_crdb_version(crdb_version, mark) + if msg: + pytest.skip(msg) @pytest.fixture diff --git a/tests/utils.py b/tests/utils.py index 1b1dfa88e..f472d796c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,7 +20,7 @@ def check_libpq_version(got, want): return check_version(got, want, "libpq") -def check_server_version(got, function): +def check_server_version(got, want): """ Verify if the server version is a version accepted. @@ -30,7 +30,6 @@ def check_server_version(got, function): and skips the test if the server version doesn't match what expected. """ - want = function.want_pg_version return check_version(got, want, "server")