]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
test(crdb): add fixture to support CockroachDB test/skip
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 16 May 2022 02:59:04 +0000 (04:59 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Jul 2022 11:58:33 +0000 (12:58 +0100)
tests/conftest.py
tests/fix_crdb.py [new file with mode: 0644]
tests/fix_db.py
tests/utils.py

index 71e574f0b2cf587a18edaec3c21b565680e2f65a..9d691b12dba91f634abc3f8926b8420aecdfcc2f 100644 (file)
@@ -11,6 +11,7 @@ pytest_plugins = (
     "tests.fix_faker",
     "tests.fix_proxy",
     "tests.fix_psycopg",
+    "tests.fix_crdb",
     "tests.pool.fix_pool",
 )
 
diff --git a/tests/fix_crdb.py b/tests/fix_crdb.py
new file mode 100644 (file)
index 0000000..e87e421
--- /dev/null
@@ -0,0 +1,120 @@
+import re
+
+from .utils import check_version
+
+
+def pytest_configure(config):
+    # register libpq marker
+    config.addinivalue_line(
+        "markers",
+        "crdb(version_expr, reason=detail): run the test only with matching CockroachDB"
+        " (e.g. '>= 21.2.10', '< 22.1', 'skip')",
+    )
+
+
+def pytest_runtest_setup(item):
+    for m in item.iter_markers(name="crdb"):
+        if len(m.args) > 1:
+            raise TypeError("max one argument expected")
+        kwargs_unk = set(m.kwargs) - {"reason"}
+        if kwargs_unk:
+            raise TypeError(f"unknown keyword arguments: {kwargs_unk}")
+
+        # Copy the want marker on the function so we can check the version
+        # after the connection has been created.
+        item.function.want_crdb = m.args[0] if m.args else "only"
+        item.function.crdb_reason = m.kwargs.get("reason")
+
+
+def check_crdb_version(pgconn, func):
+    """
+    Verify if the CockroachDB version is a version accepted.
+
+    This function is called on the tests marked with something like::
+
+        @pytest.mark.crdb(">= 21.1")
+        @pytest.mark.crdb("only")
+        @pytest.mark.crdb("skip")
+
+    and skips the test if the server version doesn't match what expected.
+    """
+    want = func.want_crdb
+    got = get_crdb_version(pgconn)
+    if got is None:
+        if want == "only":
+            return "skipping test: CockroachDB only"
+    else:
+        if want == "only":
+            rv = None
+        elif want == "skip":
+            rv = "skipping test: not supported on CockroachDB"
+        else:
+            rv = check_version(got, want, "CockroachDB")
+
+    if rv:
+        if func.crdb_reason:
+            rv = f"{rv}: {func.crdb_reason}"
+            if func.crdb_reason in crdb_reasons:
+                url = (
+                    "https://github.com/cockroachdb/cockroach/"
+                    f"issues/{crdb_reasons[func.crdb_reason]}"
+                )
+                rv = f"{rv} ({url})"
+        return rv
+
+
+def get_crdb_version(pgconn, __crdb_version=[]):
+    """
+    Return the CockroachDB server version connected.
+
+    Return None if the server is not CockroachDB, else return a number in
+    the PostgreSQL format (e.g. 21.2.10 -> 200210)
+
+    Assume all the connections are on the same db: return a cached result on
+    following calls.
+    """
+    if __crdb_version:
+        return __crdb_version[0]
+
+    bver = pgconn.parameter_status(b"crdb_version")
+    if bver:
+        sver = bver.decode()
+        m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver)
+        if not m:
+            raise ValueError(f"can't parse CockroachDB version from {sver}")
+
+        ver = int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3))
+
+    else:
+        ver = None
+
+    __crdb_version.append(ver)
+
+    return __crdb_version[0]
+
+
+# mapping from reason description to ticket number
+crdb_reasons = {
+    "2-phase commit": 22329,
+    "backend pid": 35897,
+    "batch statements": 44803,
+    "cancel": 41335,
+    "cast adds tz": 51692,
+    "cidr": 18846,
+    "composite": 27792,
+    "copy": 41608,
+    "cursor with hold": 77101,
+    "deferrable": 48307,
+    "encoding": 35882,
+    "hstore": 41284,
+    "infinity date": 41564,
+    "interval style": 35807,
+    "large objects": 243,
+    "named cursor": 41412,
+    "nested array": 32552,
+    "notify": 41522,
+    "password_encryption": 42519,
+    "range": 41282,
+    "scroll cursor": 77102,
+    "stored procedure": 1751,
+}
index 9fd3111b008f050914b7c4bf3f38f5accaa53add..15f9540ce02020f4447015d9f93a8322a34153c5 100644 (file)
@@ -6,6 +6,7 @@ from typing import List
 
 import psycopg
 from psycopg import pq
+
 from .utils import check_libpq_version, check_server_version
 
 
@@ -132,7 +133,7 @@ def pgconn(dsn, request, tracefile):
     conn = pq.PGconn.connect(dsn.encode())
     if conn.status != pq.ConnStatus.OK:
         pytest.fail(f"bad connection: {conn.error_message.decode('utf8', 'replace')}")
-    msg = check_connection_version(conn.server_version, request.function)
+    msg = check_connection_version(conn, request.function)
     if msg:
         conn.finish()
         pytest.skip(msg)
@@ -149,7 +150,7 @@ def conn(dsn, request, tracefile):
     from psycopg import Connection
 
     conn = Connection.connect(dsn)
-    msg = check_connection_version(conn.info.server_version, request.function)
+    msg = check_connection_version(conn.pgconn, request.function)
     if msg:
         conn.close()
         pytest.skip(msg)
@@ -177,7 +178,7 @@ async def aconn(dsn, request, tracefile):
     from psycopg import AsyncConnection
 
     conn = await AsyncConnection.connect(dsn)
-    msg = check_connection_version(conn.info.server_version, request.function)
+    msg = check_connection_version(conn.pgconn, request.function)
     if msg:
         await conn.close()
         pytest.skip(msg)
@@ -253,10 +254,20 @@ class ListPopAll(list):  # type: ignore[type-arg]
         return out
 
 
-def check_connection_version(got, function):
-    if not hasattr(function, "want_pg_version"):
-        return
-    return check_server_version(got, function.want_pg_version)
+def check_connection_version(pgconn, function):
+    if hasattr(function, "want_pg_version"):
+        rv = check_server_version(pgconn, function.want_pg_version)
+        if rv:
+            return rv
+
+    if hasattr(function, "want_crdb"):
+        from .fix_crdb import check_crdb_version
+
+        rv = check_crdb_version(pgconn, function)
+        if rv:
+            return rv
+
+    return None
 
 
 @pytest.fixture
index ec1cba4466f9cdff8f832eb274b3d52265c73e4e..830eba6144ae2dd2fa76d79dbc039a3f90510f4d 100644 (file)
@@ -17,10 +17,10 @@ def check_libpq_version(got, want):
 
     and skips the test if the requested version doesn't match what's loaded.
     """
-    return _check_version(got, want, "libpq")
+    return check_version(got, want, "libpq")
 
 
-def check_server_version(got, want):
+def check_server_version(pgconn, want):
     """
     Verify if the server version is a version accepted.
 
@@ -30,10 +30,11 @@ def check_server_version(got, want):
 
     and skips the test if the server version doesn't match what expected.
     """
-    return _check_version(got, want, "server")
+    got = pgconn.server_version
+    return check_version(got, want, "server")
 
 
-def _check_version(got, want, whose_version):
+def check_version(got, want, whose_version):
     """Check that a postgres-style version matches a desired spec.
 
     - The postgres-style version is a number such as 90603 for 9.6.3.