]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added some flesh to the connection.info object
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 28 Dec 2020 03:25:08 +0000 (04:25 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 22 Apr 2021 10:57:00 +0000 (11:57 +0100)
psycopg3/psycopg3/connection.py
psycopg3/psycopg3/conninfo.py
tests/test_conninfo.py

index 7492d753c98a7c320fbaefda28bc1935324d4a2d..f927fb0b04ed2fef451ea473eb700c3bc4697761 100644 (file)
@@ -26,7 +26,7 @@ from .rows import tuple_row
 from .proto import PQGen, PQGenConn, RV, RowFactory, Query, Params
 from .proto import AdaptContext, ConnectionType
 from .cursor import Cursor, AsyncCursor
-from .conninfo import make_conninfo
+from .conninfo import make_conninfo, ConnectionInfo
 from .generators import notifies
 from ._preparing import PrepareManager
 from .transaction import Transaction, AsyncTransaction
@@ -215,6 +215,10 @@ class BaseConnection(AdaptContext):
         if result.status != ExecStatus.TUPLES_OK:
             raise e.error_from_result(result, encoding=self.client_encoding)
 
+    @property
+    def info(self) -> ConnectionInfo:
+        return ConnectionInfo(self.pgconn)
+
     @property
     def adapters(self) -> adapt.AdaptersMap:
         return self._adapters
index ccc7282776f18b6c44e5cbdfc4be778554158bff..7e4148166922efd34485057bac2772707c56382c 100644 (file)
@@ -6,9 +6,11 @@ Functions to manipulate conninfo strings
 
 import re
 from typing import Any, Dict, List
+from pathlib import Path
 
 from . import pq
 from . import errors as e
+from . import encodings
 
 
 def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
@@ -89,3 +91,81 @@ def _param_escape(s: str) -> str:
         s = "'" + s + "'"
 
     return s
+
+
+class ConnectionInfo:
+    def __init__(self, pgconn: pq.proto.PGconn):
+        self.pgconn = pgconn
+
+    @property
+    def host(self) -> str:
+        return self._get_pgconn_attr("host")
+
+    @property
+    def port(self) -> int:
+        return int(self._get_pgconn_attr("port"))
+
+    @property
+    def dbname(self) -> str:
+        return self._get_pgconn_attr("db")
+
+    @property
+    def user(self) -> str:
+        return self._get_pgconn_attr("user")
+
+    @property
+    def password(self) -> str:
+        return self._get_pgconn_attr("password")
+
+    @property
+    def options(self) -> str:
+        return self._get_pgconn_attr("options")
+
+    def get_parameters(self) -> Dict[str, str]:
+        """Return the connection parameters values.
+
+        Return all the parameters set to a non-default value, which might come
+        either from the connection string or from environment variables. Don't
+        report the password (you can read it using the `password` attribute).
+        """
+        pyenc = self._pyenc
+
+        # Get the known defaults to avoid reporting them
+        defaults = {
+            i.keyword: i.compiled
+            for i in pq.Conninfo.get_defaults()
+            if i.compiled
+        }
+        # Not returned by the libq. Bug? Bet we're using SSH.
+        defaults.setdefault(b"channel_binding", b"prefer")
+        defaults[b"passfile"] = str(Path.home() / ".pgpass").encode("utf-8")
+
+        return {
+            i.keyword.decode(pyenc): i.val.decode(pyenc)
+            for i in self.pgconn.info
+            if i.val is not None
+            and i.keyword != b"password"
+            and i.val != defaults.get(i.keyword)
+        }
+
+    @property
+    def status(self) -> pq.ConnStatus:
+        return pq.ConnStatus(self.pgconn.status)
+
+    @property
+    def transaction_status(self) -> pq.TransactionStatus:
+        return pq.TransactionStatus(self.pgconn.transaction_status)
+
+    def _get_pgconn_attr(self, name: str) -> str:
+        value: bytes
+        try:
+            value = getattr(self.pgconn, name)
+        except pq.PQerror as exc:
+            raise e.OperationalError(str(exc))
+
+        return value.decode(self._pyenc)
+
+    @property
+    def _pyenc(self) -> str:
+        pgenc = self.pgconn.parameter_status(b"client_encoding") or b"UTF8"
+        return encodings.pg2py(pgenc)
index 2b045c1468d8e425111bba1819c73ed0d550bd57..5b23a05392a49020adf12848e24d40515a05fdb7 100644 (file)
@@ -1,7 +1,8 @@
 import pytest
 
-from psycopg3.conninfo import make_conninfo, conninfo_to_dict
+import psycopg3
 from psycopg3 import ProgrammingError
+from psycopg3.conninfo import make_conninfo, conninfo_to_dict
 
 snowman = "\u2603"
 
@@ -84,3 +85,57 @@ def test_no_munging():
     dsnin = "dbname=a host=b user=c password=d"
     dsnout = make_conninfo(dsnin)
     assert dsnin == dsnout
+
+
+class TestConnectionInfo:
+    @pytest.mark.parametrize(
+        "attr", [("dbname", "db"), "host", "user", "password", "options"]
+    )
+    def test_attrs(self, conn, attr):
+        if isinstance(attr, tuple):
+            info_attr, pgconn_attr = attr
+        else:
+            info_attr = pgconn_attr = attr
+
+        info_val = getattr(conn.info, info_attr)
+        pgconn_val = getattr(conn.pgconn, pgconn_attr).decode("utf-8")
+        assert info_val == pgconn_val
+
+        conn.close()
+        with pytest.raises(psycopg3.OperationalError):
+            getattr(conn.info, info_attr)
+
+    def test_port(self, conn):
+        assert conn.info.port == int(conn.pgconn.port.decode("utf-8"))
+        conn.close()
+        with pytest.raises(psycopg3.OperationalError):
+            conn.info.port
+
+    def test_get_params(self, conn, dsn):
+        info = conn.info.get_parameters()
+        for k, v in conninfo_to_dict(dsn).items():
+            assert info.get(k) == v
+
+    def test_get_params_env(self, dsn, monkeypatch):
+        dsn = conninfo_to_dict(dsn)
+        dsn.pop("application_name", None)
+
+        monkeypatch.delenv("PGAPPNAME")
+        with psycopg3.connect(**dsn) as conn:
+            assert "application_name" not in conn.info.get_parameters()
+
+        monkeypatch.setenv("PGAPPNAME", "hello test")
+        with psycopg3.connect(**dsn) as conn:
+            assert (
+                conn.info.get_parameters()["application_name"] == "hello test"
+            )
+
+    def test_status(self, conn):
+        assert conn.info.status.name == "OK"
+        conn.close()
+        assert conn.info.status.name == "BAD"
+
+    def test_transaction_status(self, conn):
+        assert conn.info.transaction_status.name == "IDLE"
+        conn.close()
+        assert conn.info.transaction_status.name == "UNKNOWN"