+from typing import Optional
+
import pytest
-from .utils import check_version
+from .utils import VersionCheck
from psycopg.crdb import CrdbConnection
# register libpq marker
config.addinivalue_line(
"markers",
- "crdb(version_expr, reason=detail): run the test only with matching CockroachDB"
- " (e.g. '>= 21.2.10', '< 22.1', 'skip')",
+ "crdb(version_expr, reason=detail): run/skip the test with matching CockroachDB"
+ " (e.g. '>= 21.2.10', '< 22.1', 'skip < 22')",
)
This function is called on the tests marked with something like::
- @pytest.mark.crdb(">= 21.1")
- @pytest.mark.crdb("only")
- @pytest.mark.crdb("skip")
+ @pytest.mark.crdb("only") # run on CRDB only, any version
+ @pytest.mark.crdb # same as above
+ @pytest.mark.crdb("only >= 21.1") # run on CRDB only >= 21.1 (not on PG)
+ @pytest.mark.crdb(">= 21.1") # same as above
+ @pytest.mark.crdb("skip") # don't run on CRDB, any version
+ @pytest.mark.crdb("skip < 22") # don't run on CRDB < 22 (run on PG)
and skips the test if the server version doesn't match what expected.
"""
assert len(mark.args) <= 1
assert not (set(mark.kwargs) - {"reason"})
- want = mark.args[0] if mark.args else "only"
- msg = None
-
- if got is None:
- if want == "only":
- msg = "skipping test: CockroachDB only"
- else:
- if want == "only":
- pass
- elif want == "skip":
- msg = crdb_skip_message(mark.kwargs.get("reason"))
- else:
- msg = check_version(got, want, "CockroachDB")
+ pred = VersionCheck.parse(mark.args[0] if mark.args else "only")
+ pred.whose = "CockroachDB"
+
+ msg = pred.get_skip_message(got)
+ if not msg:
+ return None
+
+ reason = crdb_skip_message(mark.kwargs.get("reason"))
+ if reason:
+ msg = f"{msg}: {reason}"
return msg
is_crdb = CrdbConnection.is_crdb
-def crdb_skip_message(reason):
- msg = "skipping test on CockroachDB"
+def crdb_skip_message(reason: Optional[str]) -> str:
+ msg = ""
if reason:
- msg = f"{msg}: {reason}"
+ msg = reason
if reason in _crdb_reasons:
url = (
"https://github.com/cockroachdb/cockroach/"
import gc
import re
import operator
+from typing import Callable, Optional, Tuple
import pytest
return check_version(got, want, "libpq")
-def check_server_version(got, want):
+def check_postgres_version(got, want):
"""
Verify if the server version is a version accepted.
and skips the test if the server version doesn't match what expected.
"""
- return check_version(got, want, "server")
+ return check_version(got, want, "PostgreSQL")
def check_version(got, want, whose_version):
- """Check that a postgres-style version matches a desired spec.
+ pred = VersionCheck.parse(want)
+ pred.whose = whose_version
+ return pred.get_skip_message(got)
- - The postgres-style version is a number such as 90603 for 9.6.3.
- - The want version is a spec string such as "> 9.6"
+
+class VersionCheck:
+ """
+ Helper to compare a version number with a test spec.
"""
- # convert 90603 to (9, 6, 3), 120003 to (12, 3)
- got, got_fix = divmod(got, 100)
- got_maj, got_min = divmod(got, 100)
- if got_maj >= 10:
- got = (got_maj, got_fix)
- else:
- got = (got_maj, got_min, got_fix)
-
- # Parse a spec like "> 9.6"
- m = re.match(
- r"^\s*(>=|<=|>|<|==)?\s*(?:(\d+)(?:\.(\d+)(?:\.(\d+))?)?)?\s*$",
- want,
- )
- if m is None:
- pytest.fail(f"bad wanted version spec: {want}")
-
- # convert "9.6" into (9, 6, 0), "10.3" into (10, 3)
- want_maj = int(m.group(2))
- want_min = int(m.group(3) or "0")
- want_fix = int(m.group(4) or "0")
- if want_maj >= 10:
- if want_fix:
- pytest.fail(f"bad version in {want}")
- want = (want_maj, want_min)
- else:
- want = (want_maj, want_min, want_fix)
-
- opnames = {">=": "ge", "<=": "le", ">": "gt", "<": "lt", "==": "eq"}
- op = getattr(operator, opnames[m.group(1) or "=="])
-
- if not op(got, want):
- revops = {">=": "<", "<=": ">", ">": "<=", "<": ">=", "==": "!="}
- return (
- f"{whose_version} version is {'.'.join(map(str, got))}"
- f" {revops[m.group(1)]} {'.'.join(map(str, want))}"
+
+ def __init__(
+ self,
+ *,
+ skip: bool = False,
+ op: Optional[str] = None,
+ version_tuple: Tuple[int, ...] = (),
+ whose: str = "(wanted)",
+ ):
+ self.skip = skip
+ self.op = op or "=="
+ self.version_tuple = version_tuple
+ self.whose = whose
+
+ @classmethod
+ def parse(cls, spec: str) -> "VersionCheck":
+ # Parse a spec like "> 9.6", "skip < 21.2.0"
+ m = re.match(
+ r"""(?ix)
+ ^\s* (skip|only)?
+ \s* (>=|<=|>|<)?
+ \s* (?:(\d+)(?:\.(\d+)(?:\.(\d+))?)?)?
+ \s* $
+ """,
+ spec,
)
+ if m is None:
+ pytest.fail(f"bad wanted version spec: {spec}")
+
+ skip = (m.group(1) or "only").lower() == "skip"
+ op = m.group(2)
+ version_tuple = tuple(int(n) for n in m.groups()[2:] if n)
+ return cls(skip=skip, op=op, version_tuple=version_tuple)
+
+ def get_skip_message(self, version: Optional[int]) -> Optional[str]:
+ got_tuple = self._parse_int_version(version)
+
+ msg: Optional[str] = None
+ if self.skip:
+ if got_tuple:
+ if not self.version_tuple:
+ msg = f"skip on {self.whose}"
+ elif self._match_version(got_tuple):
+ msg = (
+ f"skip on {self.whose} {self.op}"
+ f" {'.'.join(map(str, self.version_tuple))}"
+ )
+ else:
+ if not got_tuple:
+ msg = f"only for {self.whose}"
+ elif not self._match_version(got_tuple):
+ if self.version_tuple:
+ msg = (
+ f"only for {self.whose} {self.op}"
+ f" {'.'.join(map(str, self.version_tuple))}"
+ )
+ else:
+ msg = f"only for {self.whose}"
+
+ return msg
+
+ _OP_NAMES = {">=": "ge", "<=": "le", ">": "gt", "<": "lt", "==": "eq"}
+
+ def _match_version(self, got_tuple: Tuple[int, ...]) -> bool:
+ if not self.version_tuple:
+ return True
+ op: Callable[[Tuple[int, ...], Tuple[int, ...]], bool]
+ op = getattr(operator, self._OP_NAMES[self.op])
+ return op(got_tuple, self.version_tuple)
+
+ def _parse_int_version(self, version: Optional[int]) -> Tuple[int, ...]:
+ if version is None:
+ return ()
+ version, ver_fix = divmod(version, 100)
+ ver_maj, ver_min = divmod(version, 100)
+ return (ver_maj, ver_min, ver_fix)
def gc_collect():