import os
import pytest
+from .utils import check_server_version
+
def pytest_addoption(parser):
parser.addoption(
)
+def pytest_configure(config):
+ # register pg marker
+ config.addinivalue_line(
+ "markers",
+ "pg(version_expr): run the test only with matching server version"
+ " (e.g. '>= 10', '< 9.6')",
+ )
+
+
@pytest.fixture(scope="session")
def dsn(request):
"""Return the dsn used to connect to the `--test-dsn` database."""
@pytest.fixture
-def pgconn(dsn):
+def pgconn(dsn, request):
"""Return a PGconn connection open to `--test-dsn`."""
from psycopg3 import pq
pytest.fail(
f"bad connection: {conn.error_message.decode('utf8', 'replace')}"
)
+ msg = check_connection_version(conn.server_version, request.function)
+ if msg:
+ conn.finish()
+ pytest.skip(msg)
yield conn
conn.finish()
@pytest.fixture
-def conn(dsn):
+def conn(dsn, request):
"""Return a `Connection` connected to the ``--test-dsn`` database."""
from psycopg3 import Connection
conn = Connection.connect(dsn)
+ msg = check_connection_version(conn.info.server_version, request.function)
+ if msg:
+ conn.close()
+ pytest.skip(msg)
yield conn
conn.close()
@pytest.fixture
-async def aconn(dsn):
+async def aconn(dsn, request):
"""Return an `AsyncConnection` connected to the ``--test-dsn`` database."""
from psycopg3 import AsyncConnection
conn = await AsyncConnection.connect(dsn)
+ msg = check_connection_version(conn.info.server_version, request.function)
+ if msg:
+ await conn.close()
+ pytest.skip(msg)
yield conn
await conn.close()
out = self[:]
del self[:]
return out
+
+
+def check_connection_version(got, function):
+ if not hasattr(function, "pytestmark"):
+ return
+ want = [m.args[0] for m in function.pytestmark if m.name == "pg"]
+ if want:
+ return check_server_version(got, want[0])
-import re
-import operator
import sys
import pytest
+from .utils import check_libpq_version
+
def pytest_report_header(config):
try:
from psycopg3 import pq
for m in item.iter_markers(name="libpq"):
- check_libpq_version(pq.version(), m.args)
+ assert len(m.args) == 1
+ msg = check_libpq_version(pq.version(), m.args[0])
+ if msg:
+ pytest.skip(msg)
@pytest.fixture
pytest.skip(f"can't load libpq for testing: {e}")
else:
raise
-
-
-def check_libpq_version(got, want):
- """
- Verify if the libpq version is a version accepted.
-
- This function is called on the tests marked with something like::
-
- @pytest.mark.libpq(">= 12")
-
- and skips the test if the requested version doesn't match what's loaded.
-
- """
- # convert 90603 to (9, 6, 3), 120003 to (12, 3)
- got, got_fix = divmod(got, 100)
- got_maj, got_min = divmod(got, 100)
- if got_maj >= 10:
- got = (got_maj, got_fix)
- else:
- got = (got_maj, got_min, got_fix)
-
- # Parse a spec like "> 9.6"
- if len(want) != 1:
- pytest.fail("libpq marker doesn't specify a version")
- want = want[0]
- m = re.match(
- r"^\s*(>=|<=|>|<)\s*(?:(\d+)(?:\.(\d+)(?:\.(\d+))?)?)?\s*$", want
- )
- if m is None:
- pytest.fail(f"bad libpq spec: {want}")
-
- # convert "9.6" into (9, 6, 0), "10.3" into (10, 3)
- want_maj = int(m.group(2))
- want_min = int(m.group(3) or "0")
- want_fix = int(m.group(4) or "0")
- if want_maj >= 10:
- if want_fix:
- pytest.fail(f"bad libpq version in {want}")
- want = (want_maj, want_min)
- else:
- want = (want_maj, want_min, want_fix)
-
- op = getattr(
- operator, {">=": "ge", "<=": "le", ">": "gt", "<": "lt"}[m.group(1)]
- )
-
- if not op(got, want):
- revops = {">=": "<", "<=": ">", ">": "<=", "<": ">="}
- pytest.skip(
- f"skipping test: libpq loaded is {'.'.join(map(str, got))}"
- f" {revops[m.group(1)]} {'.'.join(map(str, want))}"
- )
--- /dev/null
+import re
+import operator
+
+import pytest
+
+
+def check_libpq_version(got, want):
+ """
+ Verify if the libpq version is a version accepted.
+
+ This function is called on the tests marked with something like::
+
+ @pytest.mark.libpq(">= 12")
+
+ and skips the test if the requested version doesn't match what's loaded.
+ """
+ return _check_version(got, want, "libpq")
+
+
+def check_server_version(got, want):
+ """
+ Verify if the server version is a version accepted.
+
+ This function is called on the tests marked with something like::
+
+ @pytest.mark.pg(">= 12")
+
+ and skips the test if the server version doesn't match what expected.
+ """
+ return _check_version(got, want, "server")
+
+
+def _check_version(got, want, whose_version):
+ # convert 90603 to (9, 6, 3), 120003 to (12, 3)
+ got, got_fix = divmod(got, 100)
+ got_maj, got_min = divmod(got, 100)
+ if got_maj >= 10:
+ got = (got_maj, got_fix)
+ else:
+ got = (got_maj, got_min, got_fix)
+
+ # Parse a spec like "> 9.6"
+ m = re.match(
+ r"^\s*(>=|<=|>|<)\s*(?:(\d+)(?:\.(\d+)(?:\.(\d+))?)?)?\s*$", want
+ )
+ if m is None:
+ pytest.fail(f"bad wanted version spec: {want}")
+
+ # convert "9.6" into (9, 6, 0), "10.3" into (10, 3)
+ want_maj = int(m.group(2))
+ want_min = int(m.group(3) or "0")
+ want_fix = int(m.group(4) or "0")
+ if want_maj >= 10:
+ if want_fix:
+ pytest.fail(f"bad version in {want}")
+ want = (want_maj, want_min)
+ else:
+ want = (want_maj, want_min, want_fix)
+
+ op = getattr(
+ operator, {">=": "ge", "<=": "le", ">": "gt", "<": "lt"}[m.group(1)]
+ )
+
+ if not op(got, want):
+ revops = {">=": "<", "<=": ">", ">": "<=", "<": ">="}
+ return (
+ f"skipping test: {whose_version} version is {'.'.join(map(str, got))}"
+ f" {revops[m.group(1)]} {'.'.join(map(str, want))}"
+ )