From: Daniele Varrazzo Date: Mon, 8 Jan 2024 12:22:55 +0000 (+0100) Subject: feat: add ConnDict, ConnParam to abc module X-Git-Tag: 3.2.0~95 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c170851ab85f63935c0709ff0bfd7c1e7e9dd0c0;p=thirdparty%2Fpsycopg.git feat: add ConnDict, ConnParam to abc module Formalize the type of parameter to pass to the libq. ConnMapping is required when a covariant type is needed. --- diff --git a/psycopg/psycopg/_conninfo_attempts.py b/psycopg/psycopg/_conninfo_attempts.py index 4fc0f792a..6f64f4ba1 100644 --- a/psycopg/psycopg/_conninfo_attempts.py +++ b/psycopg/psycopg/_conninfo_attempts.py @@ -14,14 +14,15 @@ import logging from random import shuffle from . import errors as e -from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def +from .abc import ConnDict, ConnMapping +from ._conninfo_utils import get_param, is_ip_address, get_param_def from ._conninfo_utils import split_attempts logger = logging.getLogger("psycopg") -def conninfo_attempts(params: ConnDict) -> list[ConnDict]: +def conninfo_attempts(params: ConnMapping) -> list[ConnDict]: """Split a set of connection params on the single attempts to perform. A connection param can perform more than one attempt more than one ``host`` diff --git a/psycopg/psycopg/_conninfo_attempts_async.py b/psycopg/psycopg/_conninfo_attempts_async.py index 6aca4ee3a..a549081e9 100644 --- a/psycopg/psycopg/_conninfo_attempts_async.py +++ b/psycopg/psycopg/_conninfo_attempts_async.py @@ -11,7 +11,8 @@ import logging from random import shuffle from . import errors as e -from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def +from .abc import ConnDict, ConnMapping +from ._conninfo_utils import get_param, is_ip_address, get_param_def from ._conninfo_utils import split_attempts if True: # ASYNC: @@ -20,7 +21,7 @@ if True: # ASYNC: logger = logging.getLogger("psycopg") -async def conninfo_attempts_async(params: ConnDict) -> list[ConnDict]: +async def conninfo_attempts_async(params: ConnMapping) -> list[ConnDict]: """Split a set of connection params on the single attempts to perform. A connection param can perform more than one attempt more than one ``host`` diff --git a/psycopg/psycopg/_conninfo_utils.py b/psycopg/psycopg/_conninfo_utils.py index 72e59a41e..a342987a0 100644 --- a/psycopg/psycopg/_conninfo_utils.py +++ b/psycopg/psycopg/_conninfo_utils.py @@ -13,16 +13,14 @@ from ipaddress import ip_address from dataclasses import dataclass from . import pq +from .abc import ConnDict, ConnMapping from . import errors as e -from ._compat import TypeAlias if TYPE_CHECKING: from typing import Any # noqa: F401 -ConnDict: TypeAlias = "dict[str, Any]" - -def split_attempts(params: ConnDict) -> list[ConnDict]: +def split_attempts(params: ConnMapping) -> list[ConnDict]: """ Split connection parameters with a sequence of hosts into separate attempts. """ @@ -50,7 +48,7 @@ def split_attempts(params: ConnDict) -> list[ConnDict]: # A single attempt to make. Don't mangle the conninfo string. if nhosts <= 1: - return [params] + return [{**params}] if len(ports) == 1: ports *= nhosts @@ -58,7 +56,7 @@ def split_attempts(params: ConnDict) -> list[ConnDict]: # Now all lists are either empty or have the same length rv = [] for i in range(nhosts): - attempt = params.copy() + attempt = {**params} if hosts: attempt["host"] = hosts[i] if hostaddrs: @@ -70,7 +68,7 @@ def split_attempts(params: ConnDict) -> list[ConnDict]: return rv -def get_param(params: ConnDict, name: str) -> str | None: +def get_param(params: ConnMapping, name: str) -> str | None: """ Return a value from a connection string. diff --git a/psycopg/psycopg/_encodings.py b/psycopg/psycopg/_encodings.py index 238271093..4949e26c6 100644 --- a/psycopg/psycopg/_encodings.py +++ b/psycopg/psycopg/_encodings.py @@ -116,7 +116,7 @@ def conninfo_encoding(conninfo: str) -> str: pgenc = params.get("client_encoding") if pgenc: try: - return pg2pyenc(pgenc.encode()) + return pg2pyenc(str(pgenc).encode()) except NotSupportedError: pass diff --git a/psycopg/psycopg/abc.py b/psycopg/psycopg/abc.py index 0080891f8..1e0b3e503 100644 --- a/psycopg/psycopg/abc.py +++ b/psycopg/psycopg/abc.py @@ -4,7 +4,7 @@ Protocol objects representing different implementations of the same classes. # Copyright (C) 2020 The Psycopg Team -from typing import Any, Callable, Generator, Mapping +from typing import Any, Dict, Callable, Generator, Mapping from typing import List, Optional, Protocol, Sequence, Tuple, Union from typing import TYPE_CHECKING @@ -30,6 +30,10 @@ Params: TypeAlias = Union[Sequence[Any], Mapping[str, Any]] ConnectionType = TypeVar("ConnectionType", bound="BaseConnection[Any]") PipelineCommand: TypeAlias = Callable[[], None] DumperKey: TypeAlias = Union[type, Tuple["DumperKey", ...]] +ConnParam: TypeAlias = Union[str, int, None] +ConnDict: TypeAlias = Dict[str, ConnParam] +ConnMapping: TypeAlias = Mapping[str, ConnParam] + # Waiting protocol types diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index dc02ce381..12873bbb3 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -18,13 +18,13 @@ from contextlib import contextmanager from . import pq from . import errors as e from . import waiting -from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV +from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, PQGenConn, Query, RV from ._tpc import Xid from .rows import Row, RowFactory, tuple_row, args_row from .adapt import AdaptersMap from ._enums import IsolationLevel from ._compat import Self -from .conninfo import ConnDict, make_conninfo, conninfo_to_dict +from .conninfo import make_conninfo, conninfo_to_dict from .conninfo import conninfo_attempts, timeout_from_conninfo from ._pipeline import Pipeline from ._encodings import pgconn_encoding @@ -83,7 +83,7 @@ class Connection(BaseConnection[Row]): context: Optional[AdaptContext] = None, row_factory: Optional[RowFactory[Row]] = None, cursor_factory: Optional[Type[Cursor[Row]]] = None, - **kwargs: Any, + **kwargs: ConnParam, ) -> Self: """ Connect to a database server and return a new `Connection` instance. @@ -95,7 +95,7 @@ class Connection(BaseConnection[Row]): attempts = conninfo_attempts(params) for attempt in attempts: try: - conninfo = make_conninfo(**attempt) + conninfo = make_conninfo("", **attempt) rv = cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout) break except e._NO_TRACEBACK as ex: diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 46269fda4..2f28fc953 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -15,13 +15,13 @@ from contextlib import asynccontextmanager from . import pq from . import errors as e from . import waiting -from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV +from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, PQGenConn, Query, RV from ._tpc import Xid from .rows import Row, AsyncRowFactory, tuple_row, args_row from .adapt import AdaptersMap from ._enums import IsolationLevel from ._compat import Self -from .conninfo import ConnDict, make_conninfo, conninfo_to_dict +from .conninfo import make_conninfo, conninfo_to_dict from .conninfo import conninfo_attempts_async, timeout_from_conninfo from ._pipeline import AsyncPipeline from ._encodings import pgconn_encoding @@ -88,7 +88,7 @@ class AsyncConnection(BaseConnection[Row]): context: Optional[AdaptContext] = None, row_factory: Optional[AsyncRowFactory[Row]] = None, cursor_factory: Optional[Type[AsyncCursor[Row]]] = None, - **kwargs: Any, + **kwargs: ConnParam, ) -> Self: """ Connect to a database server and return a new `AsyncConnection` instance. @@ -110,7 +110,7 @@ class AsyncConnection(BaseConnection[Row]): attempts = await conninfo_attempts_async(params) for attempt in attempts: try: - conninfo = make_conninfo(**attempt) + conninfo = make_conninfo("", **attempt) rv = await cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout) break except e._NO_TRACEBACK as ex: diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py index 82da58822..1401426b2 100644 --- a/psycopg/psycopg/conninfo.py +++ b/psycopg/psycopg/conninfo.py @@ -7,17 +7,15 @@ Functions to manipulate conninfo strings from __future__ import annotations import re -from typing import Any from . import pq from . import errors as e - from . import _conninfo_utils from . import _conninfo_attempts from . import _conninfo_attempts_async +from .abc import ConnParam, ConnDict # re-exoprts -ConnDict = _conninfo_utils.ConnDict conninfo_attempts = _conninfo_attempts.conninfo_attempts conninfo_attempts_async = _conninfo_attempts_async.conninfo_attempts_async @@ -27,7 +25,7 @@ conninfo_attempts_async = _conninfo_attempts_async.conninfo_attempts_async _DEFAULT_CONNECT_TIMEOUT = 130 -def make_conninfo(conninfo: str = "", **kwargs: Any) -> str: +def make_conninfo(conninfo: str = "", **kwargs: ConnParam) -> str: """ Merge a string and keyword params into a single conninfo string. @@ -68,7 +66,7 @@ def make_conninfo(conninfo: str = "", **kwargs: Any) -> str: return conninfo -def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> ConnDict: +def conninfo_to_dict(conninfo: str = "", **kwargs: ConnParam) -> ConnDict: """ Convert the `!conninfo` string into a dictionary of parameters. @@ -84,7 +82,9 @@ def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> ConnDict: #LIBPQ-CONNSTRING """ opts = _parse_conninfo(conninfo) - rv = {opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None} + rv: ConnDict = { + opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None + } for k, v in kwargs.items(): if v is not None: rv[k] = v diff --git a/tests/dbapi20.py b/tests/dbapi20.py index c873a4e66..76b0d4033 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -13,6 +13,8 @@ -- Ian Bicking ''' +from __future__ import annotations + __rcs_id__ = '$Id: dbapi20.py,v 1.11 2005/01/02 02:41:01 zenzen Exp $' __version__ = '$Revision: 1.12 $'[11:-2] __author__ = 'Stuart Bishop ' @@ -20,7 +22,10 @@ __author__ = 'Stuart Bishop ' import unittest import time import sys -from typing import Any, Dict +from typing import Any, Dict, TYPE_CHECKING + +if TYPE_CHECKING: + from psycopg.abc import ConnDict # Revision 1.12 2009/02/06 03:35:11 kf7xm @@ -101,7 +106,7 @@ class DatabaseAPI20Test(unittest.TestCase): # method is to be found driver: Any = None connect_args = () # List of arguments to pass to connect - connect_kw_args: Dict[str, Any] = {} # Keyword arguments for connect + connect_kw_args: ConnDict = {} # Keyword arguments for connect table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix diff --git a/tests/fix_db.py b/tests/fix_db.py index 9abeda79e..37ee7ac32 100644 --- a/tests/fix_db.py +++ b/tests/fix_db.py @@ -119,7 +119,7 @@ def dsn_env(dsn): continue args[opt.keyword.decode()] = os.environ[opt.envvar.decode()] - return make_conninfo(**args) + return make_conninfo("", **args) @pytest.fixture(scope="session") diff --git a/tests/fix_proxy.py b/tests/fix_proxy.py index 1d566b5e5..6a7487786 100644 --- a/tests/fix_proxy.py +++ b/tests/fix_proxy.py @@ -58,7 +58,8 @@ class Proxy: cdict = conninfo.conninfo_to_dict(server_dsn) # Get server params - host = cdict.get("host") or os.environ.get("PGHOST") + host = cdict.get("host") or os.environ.get("PGHOST", "") + assert isinstance(host, str) self.server_host = host if host and not host.startswith("/") else "localhost" self.server_port = cdict.get("port") or os.environ.get("PGPORT", "5432") @@ -70,7 +71,7 @@ class Proxy: cdict["host"] = self.client_host cdict["port"] = self.client_port cdict["sslmode"] = "disable" # not supported by the proxy - self.client_dsn = conninfo.make_conninfo(**cdict) + self.client_dsn = conninfo.make_conninfo("", **cdict) # The running proxy process self.proc = None diff --git a/tests/test_conninfo_attempts.py b/tests/test_conninfo_attempts.py index c2855760a..0f4ba1b11 100644 --- a/tests/test_conninfo_attempts.py +++ b/tests/test_conninfo_attempts.py @@ -165,14 +165,14 @@ def test_conninfo_random_multi_host(): def test_conninfo_random_multi_ips(fake_resolve): args = {"host": "alot.com"} - hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)] + hostaddrs = [str(att["hostaddr"]) for att in conninfo_attempts(args)] assert len(hostaddrs) == 20 assert hostaddrs == sorted(hostaddrs) args["load_balance_hosts"] = "disable" - hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)] + hostaddrs = [str(att["hostaddr"]) for att in conninfo_attempts(args)] assert hostaddrs == sorted(hostaddrs) args["load_balance_hosts"] = "random" - hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)] + hostaddrs = [str(att["hostaddr"]) for att in conninfo_attempts(args)] assert hostaddrs != sorted(hostaddrs) diff --git a/tests/test_conninfo_attempts_async.py b/tests/test_conninfo_attempts_async.py index bf6da880f..aada9f1e0 100644 --- a/tests/test_conninfo_attempts_async.py +++ b/tests/test_conninfo_attempts_async.py @@ -172,14 +172,14 @@ async def test_conninfo_random_multi_host(): async def test_conninfo_random_multi_ips(fake_resolve): args = {"host": "alot.com"} - hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] + hostaddrs = [str(att["hostaddr"]) for att in await conninfo_attempts_async(args)] assert len(hostaddrs) == 20 assert hostaddrs == sorted(hostaddrs) args["load_balance_hosts"] = "disable" - hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] + hostaddrs = [str(att["hostaddr"]) for att in await conninfo_attempts_async(args)] assert hostaddrs == sorted(hostaddrs) args["load_balance_hosts"] = "random" - hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] + hostaddrs = [str(att["hostaddr"]) for att in await conninfo_attempts_async(args)] assert hostaddrs != sorted(hostaddrs) diff --git a/tests/test_generators.py b/tests/test_generators.py index ecb8da987..2df55e3e0 100644 --- a/tests/test_generators.py +++ b/tests/test_generators.py @@ -25,7 +25,7 @@ def test_connect_operationalerror_pgconn(generators, dsn, monkeypatch): except KeyError: info = conninfo_to_dict(dsn) del info["password"] # should not raise per check above. - dsn = make_conninfo(**info) + dsn = make_conninfo("", **info) gen = generators.connect(dsn) with pytest.raises( diff --git a/tests/test_psycopg_dbapi20.py b/tests/test_psycopg_dbapi20.py index a89344974..b4feac792 100644 --- a/tests/test_psycopg_dbapi20.py +++ b/tests/test_psycopg_dbapi20.py @@ -1,8 +1,8 @@ import pytest import datetime as dt -from typing import Any, Dict import psycopg +from psycopg.abc import ConnDict from psycopg.conninfo import conninfo_to_dict from . import dbapi20 @@ -18,7 +18,7 @@ def with_dsn(request, session_dsn): class PsycopgTests(dbapi20.DatabaseAPI20Test): driver = psycopg # connect_args = () # set by the fixture - connect_kw_args: Dict[str, Any] = {} + connect_kw_args: ConnDict = {} def test_nextset(self): # tested elsewhere