]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
test: fix skipping versions in parametrized tests
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 May 2022 13:05:13 +0000 (15:05 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Jul 2022 11:58:34 +0000 (12:58 +0100)
In the previous implementation, all the tests were skipped, because
annotations were added to the function and ended up affecting all
generated tests, not the marked ones only.

tests/fix_crdb.py
tests/fix_db.py
tests/utils.py

index 15243fa921211ed68dadaa476a3061c012830f3b..dfa98ad9c35148b2a7457a439a0960bdcba5bf23 100644 (file)
@@ -13,21 +13,7 @@ def pytest_configure(config):
     )
 
 
-def pytest_runtest_setup(item):
-    for m in item.iter_markers(name="crdb"):
-        if len(m.args) > 1:
-            raise TypeError("max one argument expected")
-        kwargs_unk = set(m.kwargs) - {"reason"}
-        if kwargs_unk:
-            raise TypeError(f"unknown keyword arguments: {kwargs_unk}")
-
-        # Copy the want marker on the function so we can check the version
-        # after the connection has been created.
-        item.function.want_crdb = m.args[0] if m.args else "only"
-        item.function.crdb_reason = m.kwargs.get("reason")
-
-
-def check_crdb_version(got, func):
+def check_crdb_version(got, mark):
     """
     Verify if the CockroachDB version is a version accepted.
 
@@ -39,17 +25,19 @@ def check_crdb_version(got, func):
 
     and skips the test if the server version doesn't match what expected.
     """
-    want = func.want_crdb
+    assert len(mark.args) <= 1
+    assert not (set(mark.kwargs) - {"reason"})
+    want = mark.args[0] if mark.args else "only"
     msg = None
 
     if got is None:
         if want == "only":
-            return "skipping test: CockroachDB only"
+            msg = "skipping test: CockroachDB only"
     else:
         if want == "only":
             pass
         elif want == "skip":
-            msg = crdb_skip_message(func.crdb_reason)
+            msg = crdb_skip_message(mark.kwargs.get("reason"))
         else:
             msg = check_version(got, want, "CockroachDB")
 
index d5f9e4a1a8a57eeb73e5cf89f133e9d49f73d075..a069e9d95e0d8d75b06e7d6878b0a651736b6a37 100644 (file)
@@ -67,14 +67,6 @@ def pytest_configure(config):
     )
 
 
-def pytest_runtest_setup(item):
-    # Copy the want marker on the function so we can check the version
-    # after the connection has been created.
-    want_ver = [m.args[0] for m in item.iter_markers() if m.name == "pg"]
-    if want_ver:
-        item.function.want_pg_version = want_ver[0]
-
-
 @pytest.fixture(scope="session")
 def session_dsn(request):
     """
@@ -98,7 +90,7 @@ def session_dsn(request):
 @pytest.fixture
 def dsn(session_dsn, request):
     """Return the dsn used to connect to the `--test-dsn` database."""
-    check_connection_version(request.function)
+    check_connection_version(request.node)
     return session_dsn
 
 
@@ -143,7 +135,7 @@ def maybe_trace(pgconn, tracefile, function):
 @pytest.fixture
 def pgconn(dsn, request, tracefile):
     """Return a PGconn connection open to `--test-dsn`."""
-    check_connection_version(request.function)
+    check_connection_version(request.node)
 
     conn = pq.PGconn.connect(dsn.encode())
     if conn.status != pq.ConnStatus.OK:
@@ -158,7 +150,7 @@ def pgconn(dsn, request, tracefile):
 @pytest.fixture
 def conn(dsn, request, tracefile):
     """Return a `Connection` connected to the ``--test-dsn`` database."""
-    check_connection_version(request.function)
+    check_connection_version(request.node)
 
     cls = psycopg.Connection
     if crdb_version:
@@ -186,7 +178,7 @@ def pipeline(request, conn):
 @pytest.fixture
 async def aconn(dsn, request, tracefile):
     """Return an `AsyncConnection` connected to the ``--test-dsn`` database."""
-    check_connection_version(request.function)
+    check_connection_version(request.node)
 
     cls = psycopg.AsyncConnection
     if crdb_version:
@@ -263,26 +255,26 @@ class ListPopAll(list):  # type: ignore[type-arg]
         return out
 
 
-def check_connection_version(function):
+def check_connection_version(node):
     try:
         pg_version
     except NameError:
         # First connection creation failed. Let the tests fail.
-        return None
+        pytest.fail("server version not available")
 
-    if hasattr(function, "want_pg_version"):
-        msg = check_server_version(pg_version, function)
-        if msg:
-            pytest.skip(msg)
+    for mark in node.iter_markers():
+        if mark.name == "pg":
+            assert len(mark.args) == 1
+            msg = check_server_version(pg_version, mark.args[0])
+            if msg:
+                pytest.skip(msg)
 
-    if hasattr(function, "want_crdb"):
-        from .fix_crdb import check_crdb_version
-
-        msg = check_crdb_version(crdb_version, function)
-        if msg:
-            pytest.skip(msg)
+        elif mark.name == "crdb":
+            from .fix_crdb import check_crdb_version
 
-    return None
+            msg = check_crdb_version(crdb_version, mark)
+            if msg:
+                pytest.skip(msg)
 
 
 @pytest.fixture
index 1b1dfa88eb905369e017c2e4718a6eb027104eca..f472d796c4ebb2d119b68ee4b8a8d6ab2a4ee28e 100644 (file)
@@ -20,7 +20,7 @@ def check_libpq_version(got, want):
     return check_version(got, want, "libpq")
 
 
-def check_server_version(got, function):
+def check_server_version(got, want):
     """
     Verify if the server version is a version accepted.
 
@@ -30,7 +30,6 @@ def check_server_version(got, function):
 
     and skips the test if the server version doesn't match what expected.
     """
-    want = function.want_pg_version
     return check_version(got, want, "server")