]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
test: allow to mark tests to run only or skip on certain version ranges
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 7 Jun 2022 20:10:36 +0000 (22:10 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Jul 2022 11:58:34 +0000 (12:58 +0100)
tests/fix_crdb.py
tests/fix_db.py
tests/utils.py

index 8a13c28c9d2408f62197602cbb571d41ae904a64..3209bd4e6f3b3c3845a54a4f63f2ef4ece3e3292 100644 (file)
@@ -1,6 +1,8 @@
+from typing import Optional
+
 import pytest
 
-from .utils import check_version
+from .utils import VersionCheck
 from psycopg.crdb import CrdbConnection
 
 
@@ -8,8 +10,8 @@ def pytest_configure(config):
     # register libpq marker
     config.addinivalue_line(
         "markers",
-        "crdb(version_expr, reason=detail): run the test only with matching CockroachDB"
-        " (e.g. '>= 21.2.10', '< 22.1', 'skip')",
+        "crdb(version_expr, reason=detail): run/skip the test with matching CockroachDB"
+        " (e.g. '>= 21.2.10', '< 22.1', 'skip < 22')",
     )
 
 
@@ -19,27 +21,27 @@ def check_crdb_version(got, mark):
 
     This function is called on the tests marked with something like::
 
-        @pytest.mark.crdb(">= 21.1")
-        @pytest.mark.crdb("only")
-        @pytest.mark.crdb("skip")
+        @pytest.mark.crdb("only")           # run on CRDB only, any version
+        @pytest.mark.crdb                   # same as above
+        @pytest.mark.crdb("only >= 21.1")   # run on CRDB only >= 21.1 (not on PG)
+        @pytest.mark.crdb(">= 21.1")        # same as above
+        @pytest.mark.crdb("skip")           # don't run on CRDB, any version
+        @pytest.mark.crdb("skip < 22")      # don't run on CRDB < 22 (run on PG)
 
     and skips the test if the server version doesn't match what expected.
     """
     assert len(mark.args) <= 1
     assert not (set(mark.kwargs) - {"reason"})
-    want = mark.args[0] if mark.args else "only"
-    msg = None
-
-    if got is None:
-        if want == "only":
-            msg = "skipping test: CockroachDB only"
-    else:
-        if want == "only":
-            pass
-        elif want == "skip":
-            msg = crdb_skip_message(mark.kwargs.get("reason"))
-        else:
-            msg = check_version(got, want, "CockroachDB")
+    pred = VersionCheck.parse(mark.args[0] if mark.args else "only")
+    pred.whose = "CockroachDB"
+
+    msg = pred.get_skip_message(got)
+    if not msg:
+        return None
+
+    reason = crdb_skip_message(mark.kwargs.get("reason"))
+    if reason:
+        msg = f"{msg}: {reason}"
 
     return msg
 
@@ -49,10 +51,10 @@ def check_crdb_version(got, mark):
 is_crdb = CrdbConnection.is_crdb
 
 
-def crdb_skip_message(reason):
-    msg = "skipping test on CockroachDB"
+def crdb_skip_message(reason: Optional[str]) -> str:
+    msg = ""
     if reason:
-        msg = f"{msg}: {reason}"
+        msg = reason
         if reason in _crdb_reasons:
             url = (
                 "https://github.com/cockroachdb/cockroach/"
index c2c8527ee4207b7cb50a96be9ba3184efa67a6b2..49bd688d4fd399034346afcd3539838ee6f5e880 100644 (file)
@@ -8,7 +8,7 @@ import psycopg
 from psycopg import pq
 from psycopg import sql
 
-from .utils import check_libpq_version, check_server_version
+from .utils import check_libpq_version, check_postgres_version
 
 # Set by warm_up_database() the first time the dsn fixture is used
 pg_version: int
@@ -276,7 +276,7 @@ def check_connection_version(node):
     for mark in node.iter_markers():
         if mark.name == "pg":
             assert len(mark.args) == 1
-            msg = check_server_version(pg_version, mark.args[0])
+            msg = check_postgres_version(pg_version, mark.args[0])
             if msg:
                 pytest.skip(msg)
 
index f472d796c4ebb2d119b68ee4b8a8d6ab2a4ee28e..496b241cc17d7f4d54d927e9b5f4bfc29eb74d77 100644 (file)
@@ -1,6 +1,7 @@
 import gc
 import re
 import operator
+from typing import Callable, Optional, Tuple
 
 import pytest
 
@@ -20,7 +21,7 @@ def check_libpq_version(got, want):
     return check_version(got, want, "libpq")
 
 
-def check_server_version(got, want):
+def check_postgres_version(got, want):
     """
     Verify if the server version is a version accepted.
 
@@ -30,51 +31,95 @@ def check_server_version(got, want):
 
     and skips the test if the server version doesn't match what expected.
     """
-    return check_version(got, want, "server")
+    return check_version(got, want, "PostgreSQL")
 
 
 def check_version(got, want, whose_version):
-    """Check that a postgres-style version matches a desired spec.
+    pred = VersionCheck.parse(want)
+    pred.whose = whose_version
+    return pred.get_skip_message(got)
 
-    - The postgres-style version is a number such as 90603 for 9.6.3.
-    - The want version is a spec string such as "> 9.6"
+
+class VersionCheck:
+    """
+    Helper to compare a version number with a test spec.
     """
-    # convert 90603 to (9, 6, 3), 120003 to (12, 3)
-    got, got_fix = divmod(got, 100)
-    got_maj, got_min = divmod(got, 100)
-    if got_maj >= 10:
-        got = (got_maj, got_fix)
-    else:
-        got = (got_maj, got_min, got_fix)
-
-    # Parse a spec like "> 9.6"
-    m = re.match(
-        r"^\s*(>=|<=|>|<|==)?\s*(?:(\d+)(?:\.(\d+)(?:\.(\d+))?)?)?\s*$",
-        want,
-    )
-    if m is None:
-        pytest.fail(f"bad wanted version spec: {want}")
-
-    # convert "9.6" into (9, 6, 0), "10.3" into (10, 3)
-    want_maj = int(m.group(2))
-    want_min = int(m.group(3) or "0")
-    want_fix = int(m.group(4) or "0")
-    if want_maj >= 10:
-        if want_fix:
-            pytest.fail(f"bad version in {want}")
-        want = (want_maj, want_min)
-    else:
-        want = (want_maj, want_min, want_fix)
-
-    opnames = {">=": "ge", "<=": "le", ">": "gt", "<": "lt", "==": "eq"}
-    op = getattr(operator, opnames[m.group(1) or "=="])
-
-    if not op(got, want):
-        revops = {">=": "<", "<=": ">", ">": "<=", "<": ">=", "==": "!="}
-        return (
-            f"{whose_version} version is {'.'.join(map(str, got))}"
-            f" {revops[m.group(1)]} {'.'.join(map(str, want))}"
+
+    def __init__(
+        self,
+        *,
+        skip: bool = False,
+        op: Optional[str] = None,
+        version_tuple: Tuple[int, ...] = (),
+        whose: str = "(wanted)",
+    ):
+        self.skip = skip
+        self.op = op or "=="
+        self.version_tuple = version_tuple
+        self.whose = whose
+
+    @classmethod
+    def parse(cls, spec: str) -> "VersionCheck":
+        # Parse a spec like "> 9.6", "skip < 21.2.0"
+        m = re.match(
+            r"""(?ix)
+            ^\s* (skip|only)?
+            \s* (>=|<=|>|<)?
+            \s* (?:(\d+)(?:\.(\d+)(?:\.(\d+))?)?)?
+            \s* $
+            """,
+            spec,
         )
+        if m is None:
+            pytest.fail(f"bad wanted version spec: {spec}")
+
+        skip = (m.group(1) or "only").lower() == "skip"
+        op = m.group(2)
+        version_tuple = tuple(int(n) for n in m.groups()[2:] if n)
+        return cls(skip=skip, op=op, version_tuple=version_tuple)
+
+    def get_skip_message(self, version: Optional[int]) -> Optional[str]:
+        got_tuple = self._parse_int_version(version)
+
+        msg: Optional[str] = None
+        if self.skip:
+            if got_tuple:
+                if not self.version_tuple:
+                    msg = f"skip on {self.whose}"
+                elif self._match_version(got_tuple):
+                    msg = (
+                        f"skip on {self.whose} {self.op}"
+                        f" {'.'.join(map(str, self.version_tuple))}"
+                    )
+        else:
+            if not got_tuple:
+                msg = f"only for {self.whose}"
+            elif not self._match_version(got_tuple):
+                if self.version_tuple:
+                    msg = (
+                        f"only for {self.whose} {self.op}"
+                        f" {'.'.join(map(str, self.version_tuple))}"
+                    )
+                else:
+                    msg = f"only for {self.whose}"
+
+        return msg
+
+    _OP_NAMES = {">=": "ge", "<=": "le", ">": "gt", "<": "lt", "==": "eq"}
+
+    def _match_version(self, got_tuple: Tuple[int, ...]) -> bool:
+        if not self.version_tuple:
+            return True
+        op: Callable[[Tuple[int, ...], Tuple[int, ...]], bool]
+        op = getattr(operator, self._OP_NAMES[self.op])
+        return op(got_tuple, self.version_tuple)
+
+    def _parse_int_version(self, version: Optional[int]) -> Tuple[int, ...]:
+        if version is None:
+            return ()
+        version, ver_fix = divmod(version, 100)
+        ver_maj, ver_min = divmod(version, 100)
+        return (ver_maj, ver_min, ver_fix)
 
 
 def gc_collect():