]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: add psycopg.capability
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 9 Apr 2024 20:30:43 +0000 (22:30 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Apr 2024 22:09:08 +0000 (00:09 +0200)
Close #772

psycopg/psycopg/__init__.py
psycopg/psycopg/_capabilities.py [new file with mode: 0644]
tests/test_capabilities.py [new file with mode: 0644]

index 55d8c5a591127f7da78edcc91886f8de4905e6ef..57d4243fca5221c5892502164274548e2039961d 100644 (file)
@@ -21,6 +21,7 @@ from ._pipeline import Pipeline, AsyncPipeline
 from .connection import Connection
 from .transaction import Rollback, Transaction, AsyncTransaction
 from .cursor_async import AsyncCursor
+from ._capabilities import Capabilities
 from .server_cursor import AsyncServerCursor, ServerCursor
 from .client_cursor import AsyncClientCursor, ClientCursor
 from .raw_cursor import AsyncRawCursor, RawCursor
@@ -40,6 +41,9 @@ logger = logging.getLogger("psycopg")
 if logger.level == logging.NOTSET:
     logger.setLevel(logging.WARNING)
 
+# A global object to check for capabilities.
+capabilities = Capabilities()
+
 # DBAPI compliance
 connect = Connection.connect
 apilevel = "2.0"
diff --git a/psycopg/psycopg/_capabilities.py b/psycopg/psycopg/_capabilities.py
new file mode 100644 (file)
index 0000000..6b67f49
--- /dev/null
@@ -0,0 +1,89 @@
+"""
+psycopg capabilities objects
+"""
+
+# Copyright (C) 2024 The Psycopg Team
+
+from . import pq
+from . import _cmodule
+from .errors import NotSupportedError
+
+
+class Capabilities:
+    """
+    Check if a feature is supported.
+
+    Every feature check is implemented by a check method `has_SOMETHING`.
+    All the methods return a boolean stating if the capability is supported.
+    If not supported and `check` is True, raise a `NotSupportedError` instead
+    explaining why the feature is not supported.
+    """
+
+    def has_encrypt_password(self, check: bool = False) -> bool:
+        """Check if the `~PGconn.encrypt_password()` method is implemented."""
+        return self._has_feature("PGconn.encrypt_password()", 100000, check=check)
+
+    def has_hostaddr(self, check: bool = False) -> bool:
+        """Check if the `~ConnectionInfo.hostaddr` attribute is implemented."""
+        return self._has_feature("Connection.pipeline()", 120000, check=check)
+
+    def has_pipeline(self, check: bool = False) -> bool:
+        """Check if the `~Connection.pipeline()` method is implemented."""
+        return self._has_feature("Connection.pipeline()", 140000, check=check)
+
+    def has_set_trace_flag(self, check: bool = False) -> bool:
+        """Check if the `~PGconn.set_trace_flag()` method is implemented."""
+        return self._has_feature("PGconn.set_trace_flag()", 140000, check=check)
+
+    def has_cancel_safe(self, check: bool = False) -> bool:
+        """Check if the `Connection.cancel_safe()` method is implemented."""
+        return self._has_feature("Connection.cancel_safe()", 170000, check=check)
+
+    def has_pgbouncer_prepared(self, check: bool = False) -> bool:
+        """Check if prepared statements in PgBouncer are supported."""
+        return self._has_feature(
+            "PgBouncer prepared statements compatibility", 170000, check=check
+        )
+
+    def _has_feature(self, feature: str, want_version: int, check: bool) -> bool:
+        """
+        Check is a version is supported.
+
+        If `check` is true, raise an exception with an explicative message
+        explaining why the feature is not supported.
+
+        The expletive messages, are left to the user.
+        """
+        msg = ""
+        if pq.version() < want_version:
+            msg = (
+                f"the feature '{feature}' is not available:"
+                f" the client libpq version (imported from {self._libpq_source()})"
+                f" is {pq.version_pretty(pq.version())}; the feature"
+                f" requires libpq version {pq.version_pretty(want_version)}"
+                " or newer"
+            )
+
+        elif pq.__build_version__ < want_version:
+            msg = (
+                f"the feature '{feature}' is not available:"
+                f" you are using a psycopg[{pq.__impl__}] libpq wrapper built"
+                f" with libpq version {pq.version_pretty(pq.__build_version__)};"
+                " the feature requires libpq version"
+                f" {pq.version_pretty(want_version)} or newer"
+            )
+
+        if not msg:
+            return True
+        elif check:
+            raise NotSupportedError(msg)
+        else:
+            return False
+
+    def _libpq_source(self) -> str:
+        """Return a string reporting where the libpq comes from."""
+        if pq.__impl__ == "binary":
+            version: str = _cmodule.__version__ or "unknown"
+            return f"the psycopg[binary] package version {version}"
+        else:
+            return "system libraries"
diff --git a/tests/test_capabilities.py b/tests/test_capabilities.py
new file mode 100644 (file)
index 0000000..b0e3c2b
--- /dev/null
@@ -0,0 +1,68 @@
+import re
+
+import pytest
+
+from psycopg import pq, _cmodule
+from psycopg import capabilities, NotSupportedError
+
+caps = [
+    ("has_encrypt_password", "encrypt_password", 10),
+    ("has_hostaddr", "PGconn.hostaddr", 12),
+    ("has_pipeline", "Connection.pipeline()", 14),
+    ("has_set_trace_flag", "PGconn.set_trace_flag()", 14),
+    ("has_cancel_safe", "Connection.cancel_safe()", 17),
+    ("has_pgbouncer_prepared", "PgBouncer prepared statements compatibility", 17),
+]
+
+
+@pytest.mark.parametrize(
+    "method_name",
+    [
+        pytest.param(method_name, marks=pytest.mark.libpq(f">= {min_ver}"))
+        for method_name, _, min_ver in caps
+    ],
+)
+def test_has_capability(method_name):
+    method = getattr(capabilities, method_name)
+    assert method()
+    assert method(check=True)
+
+
+@pytest.mark.parametrize(
+    "method_name, label",
+    [
+        pytest.param(method_name, label, marks=pytest.mark.libpq(f"< {min_ver}"))
+        for method_name, label, min_ver in caps
+    ],
+)
+def test_no_capability(method_name, label):
+    method = getattr(capabilities, method_name)
+    assert not method()
+    with pytest.raises(NotSupportedError, match=f"'{re.escape(label)}'"):
+        method(check=True)
+
+
+def test_build_or_import_msg(monkeypatch):
+    monkeypatch.setattr(pq, "version", lambda: 140000)
+    monkeypatch.setattr(pq, "__build_version__", 139999)
+    with pytest.raises(NotSupportedError, match=r"built with libpq version 13\.99\.99"):
+        capabilities.has_pipeline(check=True)
+
+    monkeypatch.setattr(pq, "version", lambda: 139999)
+    with pytest.raises(
+        NotSupportedError, match=r"client libpq version \(.*\) is 13\.99\.99"
+    ):
+        capabilities.has_pipeline(check=True)
+
+
+def test_impl_build_error(monkeypatch):
+    monkeypatch.setattr(pq, "__build_version__", 139999)
+    monkeypatch.setattr(pq, "version", lambda: 139999)
+    if pq.__impl__ == "binary":
+        ver = _cmodule.__version__
+        assert ver
+        msg = "(imported from the psycopg[binary] package version {ver})"
+    else:
+        msg = "(imported from system libraries)"
+        with pytest.raises(NotSupportedError, match=re.escape(msg)):
+            capabilities.has_pipeline(check=True)