]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: add ConnDict, ConnParam to abc module
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 8 Jan 2024 12:22:55 +0000 (13:22 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 8 Jan 2024 12:22:55 +0000 (13:22 +0100)
Formalize the type of parameter to pass to the libq.

ConnMapping is required when a covariant type is needed.

15 files changed:
psycopg/psycopg/_conninfo_attempts.py
psycopg/psycopg/_conninfo_attempts_async.py
psycopg/psycopg/_conninfo_utils.py
psycopg/psycopg/_encodings.py
psycopg/psycopg/abc.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/conninfo.py
tests/dbapi20.py
tests/fix_db.py
tests/fix_proxy.py
tests/test_conninfo_attempts.py
tests/test_conninfo_attempts_async.py
tests/test_generators.py
tests/test_psycopg_dbapi20.py

index 4fc0f792a33ced351583efa27218a53f7c83603c..6f64f4ba1cc19d2793bea75142ac3c52148f8c04 100644 (file)
@@ -14,14 +14,15 @@ import logging
 from random import shuffle
 
 from . import errors as e
-from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def
+from .abc import ConnDict, ConnMapping
+from ._conninfo_utils import get_param, is_ip_address, get_param_def
 from ._conninfo_utils import split_attempts
 
 
 logger = logging.getLogger("psycopg")
 
 
-def conninfo_attempts(params: ConnDict) -> list[ConnDict]:
+def conninfo_attempts(params: ConnMapping) -> list[ConnDict]:
     """Split a set of connection params on the single attempts to perform.
 
     A connection param can perform more than one attempt more than one ``host``
index 6aca4ee3adbf3a6a402eca6f8d414098db18f537..a549081e9578e1b2b6bc00cc26b0aaf65991de12 100644 (file)
@@ -11,7 +11,8 @@ import logging
 from random import shuffle
 
 from . import errors as e
-from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def
+from .abc import ConnDict, ConnMapping
+from ._conninfo_utils import get_param, is_ip_address, get_param_def
 from ._conninfo_utils import split_attempts
 
 if True:  # ASYNC:
@@ -20,7 +21,7 @@ if True:  # ASYNC:
 logger = logging.getLogger("psycopg")
 
 
-async def conninfo_attempts_async(params: ConnDict) -> list[ConnDict]:
+async def conninfo_attempts_async(params: ConnMapping) -> list[ConnDict]:
     """Split a set of connection params on the single attempts to perform.
 
     A connection param can perform more than one attempt more than one ``host``
index 72e59a41e256cbf1e4e41b647d7386d079d1e9b0..a342987a09e279a41f424a20ec3f74c0c7fd901f 100644 (file)
@@ -13,16 +13,14 @@ from ipaddress import ip_address
 from dataclasses import dataclass
 
 from . import pq
+from .abc import ConnDict, ConnMapping
 from . import errors as e
-from ._compat import TypeAlias
 
 if TYPE_CHECKING:
     from typing import Any  # noqa: F401
 
-ConnDict: TypeAlias = "dict[str, Any]"
 
-
-def split_attempts(params: ConnDict) -> list[ConnDict]:
+def split_attempts(params: ConnMapping) -> list[ConnDict]:
     """
     Split connection parameters with a sequence of hosts into separate attempts.
     """
@@ -50,7 +48,7 @@ def split_attempts(params: ConnDict) -> list[ConnDict]:
 
     # A single attempt to make. Don't mangle the conninfo string.
     if nhosts <= 1:
-        return [params]
+        return [{**params}]
 
     if len(ports) == 1:
         ports *= nhosts
@@ -58,7 +56,7 @@ def split_attempts(params: ConnDict) -> list[ConnDict]:
     # Now all lists are either empty or have the same length
     rv = []
     for i in range(nhosts):
-        attempt = params.copy()
+        attempt = {**params}
         if hosts:
             attempt["host"] = hosts[i]
         if hostaddrs:
@@ -70,7 +68,7 @@ def split_attempts(params: ConnDict) -> list[ConnDict]:
     return rv
 
 
-def get_param(params: ConnDict, name: str) -> str | None:
+def get_param(params: ConnMapping, name: str) -> str | None:
     """
     Return a value from a connection string.
 
index 2382710933d89d82ef9c786e99301968bbfc33b1..4949e26c683da38f294219130bbe4b0e8739a7b5 100644 (file)
@@ -116,7 +116,7 @@ def conninfo_encoding(conninfo: str) -> str:
     pgenc = params.get("client_encoding")
     if pgenc:
         try:
-            return pg2pyenc(pgenc.encode())
+            return pg2pyenc(str(pgenc).encode())
         except NotSupportedError:
             pass
 
index 0080891f85cb4274ce719cf5803b3131bfdabe32..1e0b3e5038fe0bf3945e53482b530d24af8f9944 100644 (file)
@@ -4,7 +4,7 @@ Protocol objects representing different implementations of the same classes.
 
 # Copyright (C) 2020 The Psycopg Team
 
-from typing import Any, Callable, Generator, Mapping
+from typing import Any, Dict, Callable, Generator, Mapping
 from typing import List, Optional, Protocol, Sequence, Tuple, Union
 from typing import TYPE_CHECKING
 
@@ -30,6 +30,10 @@ Params: TypeAlias = Union[Sequence[Any], Mapping[str, Any]]
 ConnectionType = TypeVar("ConnectionType", bound="BaseConnection[Any]")
 PipelineCommand: TypeAlias = Callable[[], None]
 DumperKey: TypeAlias = Union[type, Tuple["DumperKey", ...]]
+ConnParam: TypeAlias = Union[str, int, None]
+ConnDict: TypeAlias = Dict[str, ConnParam]
+ConnMapping: TypeAlias = Mapping[str, ConnParam]
+
 
 # Waiting protocol types
 
index dc02ce3815c359a6759710a174e7a145db596a20..12873bbb3856ab3976b811cf52972d26b65bb393 100644 (file)
@@ -18,13 +18,13 @@ from contextlib import contextmanager
 from . import pq
 from . import errors as e
 from . import waiting
-from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
+from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, PQGenConn, Query, RV
 from ._tpc import Xid
 from .rows import Row, RowFactory, tuple_row, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
 from ._compat import Self
-from .conninfo import ConnDict, make_conninfo, conninfo_to_dict
+from .conninfo import make_conninfo, conninfo_to_dict
 from .conninfo import conninfo_attempts, timeout_from_conninfo
 from ._pipeline import Pipeline
 from ._encodings import pgconn_encoding
@@ -83,7 +83,7 @@ class Connection(BaseConnection[Row]):
         context: Optional[AdaptContext] = None,
         row_factory: Optional[RowFactory[Row]] = None,
         cursor_factory: Optional[Type[Cursor[Row]]] = None,
-        **kwargs: Any,
+        **kwargs: ConnParam,
     ) -> Self:
         """
         Connect to a database server and return a new `Connection` instance.
@@ -95,7 +95,7 @@ class Connection(BaseConnection[Row]):
         attempts = conninfo_attempts(params)
         for attempt in attempts:
             try:
-                conninfo = make_conninfo(**attempt)
+                conninfo = make_conninfo("", **attempt)
                 rv = cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
                 break
             except e._NO_TRACEBACK as ex:
index 46269fda4f6fc43bc2b0684a3bbf6937c01fd87b..2f28fc95305077fe94f0cb0dfa715c159c5265f1 100644 (file)
@@ -15,13 +15,13 @@ from contextlib import asynccontextmanager
 from . import pq
 from . import errors as e
 from . import waiting
-from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
+from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, PQGenConn, Query, RV
 from ._tpc import Xid
 from .rows import Row, AsyncRowFactory, tuple_row, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
 from ._compat import Self
-from .conninfo import ConnDict, make_conninfo, conninfo_to_dict
+from .conninfo import make_conninfo, conninfo_to_dict
 from .conninfo import conninfo_attempts_async, timeout_from_conninfo
 from ._pipeline import AsyncPipeline
 from ._encodings import pgconn_encoding
@@ -88,7 +88,7 @@ class AsyncConnection(BaseConnection[Row]):
         context: Optional[AdaptContext] = None,
         row_factory: Optional[AsyncRowFactory[Row]] = None,
         cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
-        **kwargs: Any,
+        **kwargs: ConnParam,
     ) -> Self:
         """
         Connect to a database server and return a new `AsyncConnection` instance.
@@ -110,7 +110,7 @@ class AsyncConnection(BaseConnection[Row]):
         attempts = await conninfo_attempts_async(params)
         for attempt in attempts:
             try:
-                conninfo = make_conninfo(**attempt)
+                conninfo = make_conninfo("", **attempt)
                 rv = await cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
                 break
             except e._NO_TRACEBACK as ex:
index 82da5882259057c9b668ff198d6f42a6aa4114b0..1401426b2ebeab32711aceba4db3643754aaa3d6 100644 (file)
@@ -7,17 +7,15 @@ Functions to manipulate conninfo strings
 from __future__ import annotations
 
 import re
-from typing import Any
 
 from . import pq
 from . import errors as e
-
 from . import _conninfo_utils
 from . import _conninfo_attempts
 from . import _conninfo_attempts_async
+from .abc import ConnParam, ConnDict
 
 # re-exoprts
-ConnDict = _conninfo_utils.ConnDict
 conninfo_attempts = _conninfo_attempts.conninfo_attempts
 conninfo_attempts_async = _conninfo_attempts_async.conninfo_attempts_async
 
@@ -27,7 +25,7 @@ conninfo_attempts_async = _conninfo_attempts_async.conninfo_attempts_async
 _DEFAULT_CONNECT_TIMEOUT = 130
 
 
-def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
+def make_conninfo(conninfo: str = "", **kwargs: ConnParam) -> str:
     """
     Merge a string and keyword params into a single conninfo string.
 
@@ -68,7 +66,7 @@ def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
     return conninfo
 
 
-def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> ConnDict:
+def conninfo_to_dict(conninfo: str = "", **kwargs: ConnParam) -> ConnDict:
     """
     Convert the `!conninfo` string into a dictionary of parameters.
 
@@ -84,7 +82,9 @@ def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> ConnDict:
            #LIBPQ-CONNSTRING
     """
     opts = _parse_conninfo(conninfo)
-    rv = {opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None}
+    rv: ConnDict = {
+        opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None
+    }
     for k, v in kwargs.items():
         if v is not None:
             rv[k] = v
index c873a4e66b63b385e138000a59609207e07a852b..76b0d40339d93c0b29bc202fc6abd618ea5de3b1 100644 (file)
@@ -13,6 +13,8 @@
     -- Ian Bicking
 '''
 
+from __future__ import annotations
+
 __rcs_id__  = '$Id: dbapi20.py,v 1.11 2005/01/02 02:41:01 zenzen Exp $'
 __version__ = '$Revision: 1.12 $'[11:-2]
 __author__ = 'Stuart Bishop <stuart@stuartbishop.net>'
@@ -20,7 +22,10 @@ __author__ = 'Stuart Bishop <stuart@stuartbishop.net>'
 import unittest
 import time
 import sys
-from typing import Any, Dict
+from typing import Any, Dict, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from psycopg.abc import ConnDict
 
 
 # Revision 1.12  2009/02/06 03:35:11  kf7xm
@@ -101,7 +106,7 @@ class DatabaseAPI20Test(unittest.TestCase):
     # method is to be found
     driver: Any = None
     connect_args = () # List of arguments to pass to connect
-    connect_kw_args: Dict[str, Any] = {} # Keyword arguments for connect
+    connect_kw_args: ConnDict = {} # Keyword arguments for connect
     table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables
 
     ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix
index 9abeda79e9b5c13d46a4a4dfe2e3c03c4cd2e98a..37ee7ac3252d117aac3844c4318483207b40cd88 100644 (file)
@@ -119,7 +119,7 @@ def dsn_env(dsn):
             continue
         args[opt.keyword.decode()] = os.environ[opt.envvar.decode()]
 
-    return make_conninfo(**args)
+    return make_conninfo("", **args)
 
 
 @pytest.fixture(scope="session")
index 1d566b5e5dc178aefc104415e0e8932b90f27779..6a74877866887c3f5cbb167172a57072ddedc22c 100644 (file)
@@ -58,7 +58,8 @@ class Proxy:
         cdict = conninfo.conninfo_to_dict(server_dsn)
 
         # Get server params
-        host = cdict.get("host") or os.environ.get("PGHOST")
+        host = cdict.get("host") or os.environ.get("PGHOST", "")
+        assert isinstance(host, str)
         self.server_host = host if host and not host.startswith("/") else "localhost"
         self.server_port = cdict.get("port") or os.environ.get("PGPORT", "5432")
 
@@ -70,7 +71,7 @@ class Proxy:
         cdict["host"] = self.client_host
         cdict["port"] = self.client_port
         cdict["sslmode"] = "disable"  # not supported by the proxy
-        self.client_dsn = conninfo.make_conninfo(**cdict)
+        self.client_dsn = conninfo.make_conninfo("", **cdict)
 
         # The running proxy process
         self.proc = None
index c2855760ac88ec7f1603aa9b91ba6255909a2fa0..0f4ba1b118baafc3c125641f4d5410e19939db60 100644 (file)
@@ -165,14 +165,14 @@ def test_conninfo_random_multi_host():
 
 def test_conninfo_random_multi_ips(fake_resolve):
     args = {"host": "alot.com"}
-    hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)]
+    hostaddrs = [str(att["hostaddr"]) for att in conninfo_attempts(args)]
     assert len(hostaddrs) == 20
     assert hostaddrs == sorted(hostaddrs)
 
     args["load_balance_hosts"] = "disable"
-    hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)]
+    hostaddrs = [str(att["hostaddr"]) for att in conninfo_attempts(args)]
     assert hostaddrs == sorted(hostaddrs)
 
     args["load_balance_hosts"] = "random"
-    hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)]
+    hostaddrs = [str(att["hostaddr"]) for att in conninfo_attempts(args)]
     assert hostaddrs != sorted(hostaddrs)
index bf6da880f4d7a2ad8b0642dd182d9f38caffa0b8..aada9f1e00f77b9c96ad19c685f17d6b41c8d544 100644 (file)
@@ -172,14 +172,14 @@ async def test_conninfo_random_multi_host():
 
 async def test_conninfo_random_multi_ips(fake_resolve):
     args = {"host": "alot.com"}
-    hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
+    hostaddrs = [str(att["hostaddr"]) for att in await conninfo_attempts_async(args)]
     assert len(hostaddrs) == 20
     assert hostaddrs == sorted(hostaddrs)
 
     args["load_balance_hosts"] = "disable"
-    hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
+    hostaddrs = [str(att["hostaddr"]) for att in await conninfo_attempts_async(args)]
     assert hostaddrs == sorted(hostaddrs)
 
     args["load_balance_hosts"] = "random"
-    hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
+    hostaddrs = [str(att["hostaddr"]) for att in await conninfo_attempts_async(args)]
     assert hostaddrs != sorted(hostaddrs)
index ecb8da987182ca7d626f3cb8c0a9c4608f3fc3e6..2df55e3e08eede59d72df371e28a4e3cd741ffa2 100644 (file)
@@ -25,7 +25,7 @@ def test_connect_operationalerror_pgconn(generators, dsn, monkeypatch):
         except KeyError:
             info = conninfo_to_dict(dsn)
             del info["password"]  # should not raise per check above.
-            dsn = make_conninfo(**info)
+            dsn = make_conninfo("", **info)
 
         gen = generators.connect(dsn)
         with pytest.raises(
index a89344974207a7177d4bfd13262f9d1bb7c8b806..b4feac792ab1783cb4429795b715a39f5d0d4a3d 100644 (file)
@@ -1,8 +1,8 @@
 import pytest
 import datetime as dt
-from typing import Any, Dict
 
 import psycopg
+from psycopg.abc import ConnDict
 from psycopg.conninfo import conninfo_to_dict
 
 from . import dbapi20
@@ -18,7 +18,7 @@ def with_dsn(request, session_dsn):
 class PsycopgTests(dbapi20.DatabaseAPI20Test):
     driver = psycopg
     # connect_args = () # set by the fixture
-    connect_kw_args: Dict[str, Any] = {}
+    connect_kw_args: ConnDict = {}
 
     def test_nextset(self):
         # tested elsewhere