]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(capabilities): add caching to capabilities check 782/head
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Apr 2024 12:04:57 +0000 (14:04 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Apr 2024 22:09:08 +0000 (00:09 +0200)
psycopg/psycopg/_capabilities.py
psycopg/psycopg/_pipeline.py
tests/test_capabilities.py

index 0054b9a70d4a57052e70337e7343242f9d90cb5e..44725a8af06b83a593f91695b714e3ce03f88a47 100644 (file)
@@ -4,6 +4,8 @@ psycopg capabilities objects
 
 # Copyright (C) 2024 The Psycopg Team
 
+from typing import Dict
+
 from . import pq
 from . import _cmodule
 from .errors import NotSupportedError
@@ -14,6 +16,9 @@ class Capabilities:
     An object to check if a feature is supported by the libpq available on the client.
     """
 
+    def __init__(self) -> None:
+        self._cache: Dict[str, str] = {}
+
     def has_encrypt_password(self, check: bool = False) -> bool:
         """Check if the `PGconn.encrypt_password()` method is implemented.
 
@@ -67,9 +72,27 @@ class Capabilities:
 
         The expletive messages, are left to the user.
         """
-        msg = ""
+        if feature in self._cache:
+            msg = self._cache[feature]
+        else:
+            msg = self._get_unsupported_message(feature, want_version)
+            self._cache[feature] = msg
+
+        if not msg:
+            return True
+        elif check:
+            raise NotSupportedError(msg)
+        else:
+            return False
+
+    def _get_unsupported_message(self, feature: str, want_version: int) -> str:
+        """
+        Return a descriptinve message to describe why a feature is unsupported.
+
+        Return an empty string if the feature is supported.
+        """
         if pq.version() < want_version:
-            msg = (
+            return (
                 f"the feature '{feature}' is not available:"
                 f" the client libpq version (imported from {self._libpq_source()})"
                 f" is {pq.version_pretty(pq.version())}; the feature"
@@ -78,20 +101,15 @@ class Capabilities:
             )
 
         elif pq.__build_version__ < want_version:
-            msg = (
+            return (
                 f"the feature '{feature}' is not available:"
                 f" you are using a psycopg[{pq.__impl__}] libpq wrapper built"
                 f" with libpq version {pq.version_pretty(pq.__build_version__)};"
                 " the feature requires libpq version"
                 f" {pq.version_pretty(want_version)} or newer"
             )
-
-        if not msg:
-            return True
-        elif check:
-            raise NotSupportedError(msg)
         else:
-            return False
+            return ""
 
     def _libpq_source(self) -> str:
         """Return a string reporting where the libpq comes from."""
index 7c6ce5c82563d563cb54d7fd9a7c34bc9ce9b5ac..0c4671e2bf5fbd5d14d6c894ca4c4086afd17fc4 100644 (file)
@@ -42,7 +42,6 @@ logger = logging.getLogger("psycopg")
 class BasePipeline:
     command_queue: Deque[PipelineCommand]
     result_queue: Deque[PendingResult]
-    _is_supported: Optional[bool] = None
 
     def __init__(self, conn: "BaseConnection[Any]") -> None:
         self._conn = conn
@@ -63,13 +62,10 @@ class BasePipeline:
     @classmethod
     def is_supported(cls) -> bool:
         """Return `!True` if the psycopg libpq wrapper supports pipeline mode."""
-        if BasePipeline._is_supported is None:
-            BasePipeline._is_supported = capabilities.has_pipeline()
-        return BasePipeline._is_supported
+        return capabilities.has_pipeline()
 
     def _enter_gen(self) -> PQGen[None]:
-        if not self._is_supported:
-            capabilities.has_pipeline(check=True)
+        capabilities.has_pipeline(check=True)
         if self.level == 0:
             self.pgconn.enter_pipeline_mode()
         elif self.command_queue or self.pgconn.transaction_status == ACTIVE:
index 6a3a2da9394600109436eb8651a403a5a768446c..56331e4138c9c91859c15f1a7919f405691f1188 100644 (file)
@@ -3,7 +3,7 @@ import re
 import pytest
 
 from psycopg import pq, _cmodule
-from psycopg import capabilities, NotSupportedError
+from psycopg import Capabilities, capabilities, NotSupportedError
 
 caps = [
     ("has_encrypt_password", "pq.PGconn.encrypt_password()", 10),
@@ -46,13 +46,13 @@ def test_build_or_import_msg(monkeypatch):
     monkeypatch.setattr(pq, "version", lambda: 140000)
     monkeypatch.setattr(pq, "__build_version__", 139999)
     with pytest.raises(NotSupportedError, match=r"built with libpq version 13\.99\.99"):
-        capabilities.has_pipeline(check=True)
+        Capabilities().has_pipeline(check=True)
 
     monkeypatch.setattr(pq, "version", lambda: 139999)
     with pytest.raises(
         NotSupportedError, match=r"client libpq version \(.*\) is 13\.99\.99"
     ):
-        capabilities.has_pipeline(check=True)
+        Capabilities().has_pipeline(check=True)
 
 
 def test_impl_build_error(monkeypatch):
@@ -65,4 +65,27 @@ def test_impl_build_error(monkeypatch):
     else:
         msg = "(imported from system libraries)"
         with pytest.raises(NotSupportedError, match=re.escape(msg)):
-            capabilities.has_pipeline(check=True)
+            Capabilities().has_pipeline(check=True)
+
+
+def test_caching(monkeypatch):
+
+    version = 150000
+
+    caps = Capabilities()
+    called = 0
+
+    def ver():
+        nonlocal called
+        called += 1
+        return version
+
+    monkeypatch.setattr(pq, "version", ver)
+    monkeypatch.setattr(pq, "__build_version__", version)
+
+    caps.has_pipeline()
+    assert called == 1
+    caps.has_pipeline()
+    assert called == 1
+    caps.has_hostaddr()
+    assert called == 2