From: Daniele Varrazzo Date: Mon, 16 May 2022 02:59:04 +0000 (+0200) Subject: test(crdb): add fixture to support CockroachDB test/skip X-Git-Tag: 3.1~49^2~66 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=fdb3da095863bc74bbcbc2f69e0133074a2c9785;p=thirdparty%2Fpsycopg.git test(crdb): add fixture to support CockroachDB test/skip --- diff --git a/tests/conftest.py b/tests/conftest.py index 71e574f0b..9d691b12d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ pytest_plugins = ( "tests.fix_faker", "tests.fix_proxy", "tests.fix_psycopg", + "tests.fix_crdb", "tests.pool.fix_pool", ) diff --git a/tests/fix_crdb.py b/tests/fix_crdb.py new file mode 100644 index 000000000..e87e4215a --- /dev/null +++ b/tests/fix_crdb.py @@ -0,0 +1,120 @@ +import re + +from .utils import check_version + + +def pytest_configure(config): + # register libpq marker + config.addinivalue_line( + "markers", + "crdb(version_expr, reason=detail): run the test only with matching CockroachDB" + " (e.g. '>= 21.2.10', '< 22.1', 'skip')", + ) + + +def pytest_runtest_setup(item): + for m in item.iter_markers(name="crdb"): + if len(m.args) > 1: + raise TypeError("max one argument expected") + kwargs_unk = set(m.kwargs) - {"reason"} + if kwargs_unk: + raise TypeError(f"unknown keyword arguments: {kwargs_unk}") + + # Copy the want marker on the function so we can check the version + # after the connection has been created. + item.function.want_crdb = m.args[0] if m.args else "only" + item.function.crdb_reason = m.kwargs.get("reason") + + +def check_crdb_version(pgconn, func): + """ + Verify if the CockroachDB version is a version accepted. + + This function is called on the tests marked with something like:: + + @pytest.mark.crdb(">= 21.1") + @pytest.mark.crdb("only") + @pytest.mark.crdb("skip") + + and skips the test if the server version doesn't match what expected. + """ + want = func.want_crdb + got = get_crdb_version(pgconn) + if got is None: + if want == "only": + return "skipping test: CockroachDB only" + else: + if want == "only": + rv = None + elif want == "skip": + rv = "skipping test: not supported on CockroachDB" + else: + rv = check_version(got, want, "CockroachDB") + + if rv: + if func.crdb_reason: + rv = f"{rv}: {func.crdb_reason}" + if func.crdb_reason in crdb_reasons: + url = ( + "https://github.com/cockroachdb/cockroach/" + f"issues/{crdb_reasons[func.crdb_reason]}" + ) + rv = f"{rv} ({url})" + return rv + + +def get_crdb_version(pgconn, __crdb_version=[]): + """ + Return the CockroachDB server version connected. + + Return None if the server is not CockroachDB, else return a number in + the PostgreSQL format (e.g. 21.2.10 -> 200210) + + Assume all the connections are on the same db: return a cached result on + following calls. + """ + if __crdb_version: + return __crdb_version[0] + + bver = pgconn.parameter_status(b"crdb_version") + if bver: + sver = bver.decode() + m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver) + if not m: + raise ValueError(f"can't parse CockroachDB version from {sver}") + + ver = int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3)) + + else: + ver = None + + __crdb_version.append(ver) + + return __crdb_version[0] + + +# mapping from reason description to ticket number +crdb_reasons = { + "2-phase commit": 22329, + "backend pid": 35897, + "batch statements": 44803, + "cancel": 41335, + "cast adds tz": 51692, + "cidr": 18846, + "composite": 27792, + "copy": 41608, + "cursor with hold": 77101, + "deferrable": 48307, + "encoding": 35882, + "hstore": 41284, + "infinity date": 41564, + "interval style": 35807, + "large objects": 243, + "named cursor": 41412, + "nested array": 32552, + "notify": 41522, + "password_encryption": 42519, + "range": 41282, + "scroll cursor": 77102, + "stored procedure": 1751, +} diff --git a/tests/fix_db.py b/tests/fix_db.py index 9fd3111b0..15f9540ce 100644 --- a/tests/fix_db.py +++ b/tests/fix_db.py @@ -6,6 +6,7 @@ from typing import List import psycopg from psycopg import pq + from .utils import check_libpq_version, check_server_version @@ -132,7 +133,7 @@ def pgconn(dsn, request, tracefile): conn = pq.PGconn.connect(dsn.encode()) if conn.status != pq.ConnStatus.OK: pytest.fail(f"bad connection: {conn.error_message.decode('utf8', 'replace')}") - msg = check_connection_version(conn.server_version, request.function) + msg = check_connection_version(conn, request.function) if msg: conn.finish() pytest.skip(msg) @@ -149,7 +150,7 @@ def conn(dsn, request, tracefile): from psycopg import Connection conn = Connection.connect(dsn) - msg = check_connection_version(conn.info.server_version, request.function) + msg = check_connection_version(conn.pgconn, request.function) if msg: conn.close() pytest.skip(msg) @@ -177,7 +178,7 @@ async def aconn(dsn, request, tracefile): from psycopg import AsyncConnection conn = await AsyncConnection.connect(dsn) - msg = check_connection_version(conn.info.server_version, request.function) + msg = check_connection_version(conn.pgconn, request.function) if msg: await conn.close() pytest.skip(msg) @@ -253,10 +254,20 @@ class ListPopAll(list): # type: ignore[type-arg] return out -def check_connection_version(got, function): - if not hasattr(function, "want_pg_version"): - return - return check_server_version(got, function.want_pg_version) +def check_connection_version(pgconn, function): + if hasattr(function, "want_pg_version"): + rv = check_server_version(pgconn, function.want_pg_version) + if rv: + return rv + + if hasattr(function, "want_crdb"): + from .fix_crdb import check_crdb_version + + rv = check_crdb_version(pgconn, function) + if rv: + return rv + + return None @pytest.fixture diff --git a/tests/utils.py b/tests/utils.py index ec1cba446..830eba614 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -17,10 +17,10 @@ def check_libpq_version(got, want): and skips the test if the requested version doesn't match what's loaded. """ - return _check_version(got, want, "libpq") + return check_version(got, want, "libpq") -def check_server_version(got, want): +def check_server_version(pgconn, want): """ Verify if the server version is a version accepted. @@ -30,10 +30,11 @@ def check_server_version(got, want): and skips the test if the server version doesn't match what expected. """ - return _check_version(got, want, "server") + got = pgconn.server_version + return check_version(got, want, "server") -def _check_version(got, want, whose_version): +def check_version(got, want, whose_version): """Check that a postgres-style version matches a desired spec. - The postgres-style version is a number such as 90603 for 9.6.3.