From: Daniele Varrazzo Date: Tue, 4 May 2021 16:08:30 +0000 (+0200) Subject: Add pytest marker to skip tests on certain server versions X-Git-Tag: 3.0.dev0~48^2~14 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5cb1a6e59b192d82a38e13864dccb8a3248054b5;p=thirdparty%2Fpsycopg.git Add pytest marker to skip tests on certain server versions --- diff --git a/tests/fix_db.py b/tests/fix_db.py index 1cab75592..ee01caa4a 100644 --- a/tests/fix_db.py +++ b/tests/fix_db.py @@ -1,6 +1,8 @@ import os import pytest +from .utils import check_server_version + def pytest_addoption(parser): parser.addoption( @@ -12,6 +14,15 @@ def pytest_addoption(parser): ) +def pytest_configure(config): + # register pg marker + config.addinivalue_line( + "markers", + "pg(version_expr): run the test only with matching server version" + " (e.g. '>= 10', '< 9.6')", + ) + + @pytest.fixture(scope="session") def dsn(request): """Return the dsn used to connect to the `--test-dsn` database.""" @@ -22,7 +33,7 @@ def dsn(request): @pytest.fixture -def pgconn(dsn): +def pgconn(dsn, request): """Return a PGconn connection open to `--test-dsn`.""" from psycopg3 import pq @@ -31,26 +42,38 @@ def pgconn(dsn): pytest.fail( f"bad connection: {conn.error_message.decode('utf8', 'replace')}" ) + msg = check_connection_version(conn.server_version, request.function) + if msg: + conn.finish() + pytest.skip(msg) yield conn conn.finish() @pytest.fixture -def conn(dsn): +def conn(dsn, request): """Return a `Connection` connected to the ``--test-dsn`` database.""" from psycopg3 import Connection conn = Connection.connect(dsn) + msg = check_connection_version(conn.info.server_version, request.function) + if msg: + conn.close() + pytest.skip(msg) yield conn conn.close() @pytest.fixture -async def aconn(dsn): +async def aconn(dsn, request): """Return an `AsyncConnection` connected to the ``--test-dsn`` database.""" from psycopg3 import AsyncConnection conn = await AsyncConnection.connect(dsn) + msg = check_connection_version(conn.info.server_version, request.function) + if msg: + await conn.close() + pytest.skip(msg) yield conn await conn.close() @@ -107,3 +130,11 @@ class ListPopAll(list): out = self[:] del self[:] return out + + +def check_connection_version(got, function): + if not hasattr(function, "pytestmark"): + return + want = [m.args[0] for m in function.pytestmark if m.name == "pg"] + if want: + return check_server_version(got, want[0]) diff --git a/tests/fix_pq.py b/tests/fix_pq.py index bb5a66ff9..115803f98 100644 --- a/tests/fix_pq.py +++ b/tests/fix_pq.py @@ -1,9 +1,9 @@ -import re -import operator import sys import pytest +from .utils import check_libpq_version + def pytest_report_header(config): try: @@ -30,7 +30,10 @@ def pytest_runtest_setup(item): from psycopg3 import pq for m in item.iter_markers(name="libpq"): - check_libpq_version(pq.version(), m.args) + assert len(m.args) == 1 + msg = check_libpq_version(pq.version(), m.args[0]) + if msg: + pytest.skip(msg) @pytest.fixture @@ -53,55 +56,3 @@ def libpq(): pytest.skip(f"can't load libpq for testing: {e}") else: raise - - -def check_libpq_version(got, want): - """ - Verify if the libpq version is a version accepted. - - This function is called on the tests marked with something like:: - - @pytest.mark.libpq(">= 12") - - and skips the test if the requested version doesn't match what's loaded. - - """ - # convert 90603 to (9, 6, 3), 120003 to (12, 3) - got, got_fix = divmod(got, 100) - got_maj, got_min = divmod(got, 100) - if got_maj >= 10: - got = (got_maj, got_fix) - else: - got = (got_maj, got_min, got_fix) - - # Parse a spec like "> 9.6" - if len(want) != 1: - pytest.fail("libpq marker doesn't specify a version") - want = want[0] - m = re.match( - r"^\s*(>=|<=|>|<)\s*(?:(\d+)(?:\.(\d+)(?:\.(\d+))?)?)?\s*$", want - ) - if m is None: - pytest.fail(f"bad libpq spec: {want}") - - # convert "9.6" into (9, 6, 0), "10.3" into (10, 3) - want_maj = int(m.group(2)) - want_min = int(m.group(3) or "0") - want_fix = int(m.group(4) or "0") - if want_maj >= 10: - if want_fix: - pytest.fail(f"bad libpq version in {want}") - want = (want_maj, want_min) - else: - want = (want_maj, want_min, want_fix) - - op = getattr( - operator, {">=": "ge", "<=": "le", ">": "gt", "<": "lt"}[m.group(1)] - ) - - if not op(got, want): - revops = {">=": "<", "<=": ">", ">": "<=", "<": ">="} - pytest.skip( - f"skipping test: libpq loaded is {'.'.join(map(str, got))}" - f" {revops[m.group(1)]} {'.'.join(map(str, want))}" - ) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 000000000..e235d6915 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,69 @@ +import re +import operator + +import pytest + + +def check_libpq_version(got, want): + """ + Verify if the libpq version is a version accepted. + + This function is called on the tests marked with something like:: + + @pytest.mark.libpq(">= 12") + + and skips the test if the requested version doesn't match what's loaded. + """ + return _check_version(got, want, "libpq") + + +def check_server_version(got, want): + """ + Verify if the server version is a version accepted. + + This function is called on the tests marked with something like:: + + @pytest.mark.pg(">= 12") + + and skips the test if the server version doesn't match what expected. + """ + return _check_version(got, want, "server") + + +def _check_version(got, want, whose_version): + # convert 90603 to (9, 6, 3), 120003 to (12, 3) + got, got_fix = divmod(got, 100) + got_maj, got_min = divmod(got, 100) + if got_maj >= 10: + got = (got_maj, got_fix) + else: + got = (got_maj, got_min, got_fix) + + # Parse a spec like "> 9.6" + m = re.match( + r"^\s*(>=|<=|>|<)\s*(?:(\d+)(?:\.(\d+)(?:\.(\d+))?)?)?\s*$", want + ) + if m is None: + pytest.fail(f"bad wanted version spec: {want}") + + # convert "9.6" into (9, 6, 0), "10.3" into (10, 3) + want_maj = int(m.group(2)) + want_min = int(m.group(3) or "0") + want_fix = int(m.group(4) or "0") + if want_maj >= 10: + if want_fix: + pytest.fail(f"bad version in {want}") + want = (want_maj, want_min) + else: + want = (want_maj, want_min, want_fix) + + op = getattr( + operator, {">=": "ge", "<=": "le", ">": "gt", "<": "lt"}[m.group(1)] + ) + + if not op(got, want): + revops = {">=": "<", "<=": ">", ">": "<=", "<": ">="} + return ( + f"skipping test: {whose_version} version is {'.'.join(map(str, got))}" + f" {revops[m.group(1)]} {'.'.join(map(str, want))}" + )