]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add pytest marker to skip tests on certain server versions
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 4 May 2021 16:08:30 +0000 (18:08 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 4 May 2021 21:24:10 +0000 (23:24 +0200)
tests/fix_db.py
tests/fix_pq.py
tests/utils.py [new file with mode: 0644]

index 1cab75592dcdbb0b16999fa722363ef1f2932801..ee01caa4a2cd5abc45346b0f85be264ce1d57c35 100644 (file)
@@ -1,6 +1,8 @@
 import os
 import pytest
 
+from .utils import check_server_version
+
 
 def pytest_addoption(parser):
     parser.addoption(
@@ -12,6 +14,15 @@ def pytest_addoption(parser):
     )
 
 
+def pytest_configure(config):
+    # register pg marker
+    config.addinivalue_line(
+        "markers",
+        "pg(version_expr): run the test only with matching server version"
+        " (e.g. '>= 10', '< 9.6')",
+    )
+
+
 @pytest.fixture(scope="session")
 def dsn(request):
     """Return the dsn used to connect to the `--test-dsn` database."""
@@ -22,7 +33,7 @@ def dsn(request):
 
 
 @pytest.fixture
-def pgconn(dsn):
+def pgconn(dsn, request):
     """Return a PGconn connection open to `--test-dsn`."""
     from psycopg3 import pq
 
@@ -31,26 +42,38 @@ def pgconn(dsn):
         pytest.fail(
             f"bad connection: {conn.error_message.decode('utf8', 'replace')}"
         )
+    msg = check_connection_version(conn.server_version, request.function)
+    if msg:
+        conn.finish()
+        pytest.skip(msg)
     yield conn
     conn.finish()
 
 
 @pytest.fixture
-def conn(dsn):
+def conn(dsn, request):
     """Return a `Connection` connected to the ``--test-dsn`` database."""
     from psycopg3 import Connection
 
     conn = Connection.connect(dsn)
+    msg = check_connection_version(conn.info.server_version, request.function)
+    if msg:
+        conn.close()
+        pytest.skip(msg)
     yield conn
     conn.close()
 
 
 @pytest.fixture
-async def aconn(dsn):
+async def aconn(dsn, request):
     """Return an `AsyncConnection` connected to the ``--test-dsn`` database."""
     from psycopg3 import AsyncConnection
 
     conn = await AsyncConnection.connect(dsn)
+    msg = check_connection_version(conn.info.server_version, request.function)
+    if msg:
+        await conn.close()
+        pytest.skip(msg)
     yield conn
     await conn.close()
 
@@ -107,3 +130,11 @@ class ListPopAll(list):
         out = self[:]
         del self[:]
         return out
+
+
+def check_connection_version(got, function):
+    if not hasattr(function, "pytestmark"):
+        return
+    want = [m.args[0] for m in function.pytestmark if m.name == "pg"]
+    if want:
+        return check_server_version(got, want[0])
index bb5a66ff9e0e48f4e821fcb9292a581e85118c9d..115803f98c8481c2f16f69a584439f25b77734ce 100644 (file)
@@ -1,9 +1,9 @@
-import re
-import operator
 import sys
 
 import pytest
 
+from .utils import check_libpq_version
+
 
 def pytest_report_header(config):
     try:
@@ -30,7 +30,10 @@ def pytest_runtest_setup(item):
     from psycopg3 import pq
 
     for m in item.iter_markers(name="libpq"):
-        check_libpq_version(pq.version(), m.args)
+        assert len(m.args) == 1
+        msg = check_libpq_version(pq.version(), m.args[0])
+        if msg:
+            pytest.skip(msg)
 
 
 @pytest.fixture
@@ -53,55 +56,3 @@ def libpq():
             pytest.skip(f"can't load libpq for testing: {e}")
         else:
             raise
-
-
-def check_libpq_version(got, want):
-    """
-    Verify if the libpq version is a version accepted.
-
-    This function is called on the tests marked with something like::
-
-        @pytest.mark.libpq(">= 12")
-
-    and skips the test if the requested version doesn't match what's loaded.
-
-    """
-    # convert 90603 to (9, 6, 3), 120003 to (12, 3)
-    got, got_fix = divmod(got, 100)
-    got_maj, got_min = divmod(got, 100)
-    if got_maj >= 10:
-        got = (got_maj, got_fix)
-    else:
-        got = (got_maj, got_min, got_fix)
-
-    # Parse a spec like "> 9.6"
-    if len(want) != 1:
-        pytest.fail("libpq marker doesn't specify a version")
-    want = want[0]
-    m = re.match(
-        r"^\s*(>=|<=|>|<)\s*(?:(\d+)(?:\.(\d+)(?:\.(\d+))?)?)?\s*$", want
-    )
-    if m is None:
-        pytest.fail(f"bad libpq spec: {want}")
-
-    # convert "9.6" into (9, 6, 0), "10.3" into (10, 3)
-    want_maj = int(m.group(2))
-    want_min = int(m.group(3) or "0")
-    want_fix = int(m.group(4) or "0")
-    if want_maj >= 10:
-        if want_fix:
-            pytest.fail(f"bad libpq version in {want}")
-        want = (want_maj, want_min)
-    else:
-        want = (want_maj, want_min, want_fix)
-
-    op = getattr(
-        operator, {">=": "ge", "<=": "le", ">": "gt", "<": "lt"}[m.group(1)]
-    )
-
-    if not op(got, want):
-        revops = {">=": "<", "<=": ">", ">": "<=", "<": ">="}
-        pytest.skip(
-            f"skipping test: libpq loaded is {'.'.join(map(str, got))}"
-            f" {revops[m.group(1)]} {'.'.join(map(str, want))}"
-        )
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644 (file)
index 0000000..e235d69
--- /dev/null
@@ -0,0 +1,69 @@
+import re
+import operator
+
+import pytest
+
+
+def check_libpq_version(got, want):
+    """
+    Verify if the libpq version is a version accepted.
+
+    This function is called on the tests marked with something like::
+
+        @pytest.mark.libpq(">= 12")
+
+    and skips the test if the requested version doesn't match what's loaded.
+    """
+    return _check_version(got, want, "libpq")
+
+
+def check_server_version(got, want):
+    """
+    Verify if the server version is a version accepted.
+
+    This function is called on the tests marked with something like::
+
+        @pytest.mark.pg(">= 12")
+
+    and skips the test if the server version doesn't match what expected.
+    """
+    return _check_version(got, want, "server")
+
+
+def _check_version(got, want, whose_version):
+    # convert 90603 to (9, 6, 3), 120003 to (12, 3)
+    got, got_fix = divmod(got, 100)
+    got_maj, got_min = divmod(got, 100)
+    if got_maj >= 10:
+        got = (got_maj, got_fix)
+    else:
+        got = (got_maj, got_min, got_fix)
+
+    # Parse a spec like "> 9.6"
+    m = re.match(
+        r"^\s*(>=|<=|>|<)\s*(?:(\d+)(?:\.(\d+)(?:\.(\d+))?)?)?\s*$", want
+    )
+    if m is None:
+        pytest.fail(f"bad wanted version spec: {want}")
+
+    # convert "9.6" into (9, 6, 0), "10.3" into (10, 3)
+    want_maj = int(m.group(2))
+    want_min = int(m.group(3) or "0")
+    want_fix = int(m.group(4) or "0")
+    if want_maj >= 10:
+        if want_fix:
+            pytest.fail(f"bad version in {want}")
+        want = (want_maj, want_min)
+    else:
+        want = (want_maj, want_min, want_fix)
+
+    op = getattr(
+        operator, {">=": "ge", "<=": "le", ">": "gt", "<": "lt"}[m.group(1)]
+    )
+
+    if not op(got, want):
+        revops = {">=": "<", "<=": ">", ">": "<=", "<": ">="}
+        return (
+            f"skipping test: {whose_version} version is {'.'.join(map(str, got))}"
+            f" {revops[m.group(1)]} {'.'.join(map(str, want))}"
+        )