From: Daniele Varrazzo Date: Sat, 21 Aug 2021 12:39:36 +0000 (+0200) Subject: Add _get_connection_params method to connections X-Git-Tag: 3.0.beta1~55 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=10249a013848ddb6f12b064750ef98b9f8a7f95e;p=thirdparty%2Fpsycopg.git Add _get_connection_params method to connections Move there the connect_timeout extraction logic, but the method is intended to do more elaboration on the parameters before connection, which should include asynchronous DNS lookup and possibly SRV lookup (RFC 2782) and allows overriding in subclasses to allow experimenting. --- diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index 0679387a2..fa4b335bb 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -8,7 +8,7 @@ import logging import warnings import threading from types import TracebackType -from typing import Any, Callable, cast, Generic, Iterator, List +from typing import Any, Callable, cast, Dict, Generic, Iterator, List from typing import NamedTuple, Optional, Type, TypeVar, Union from typing import overload, TYPE_CHECKING from weakref import ref, ReferenceType @@ -28,7 +28,7 @@ from .rows import Row, RowFactory, tuple_row, TupleRow from ._enums import IsolationLevel from .cursor import Cursor from ._cmodule import _psycopg -from .conninfo import _conninfo_connect_timeout, ConnectionInfo +from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo from .generators import notifies from ._preparing import PrepareManager from .transaction import Transaction @@ -564,10 +564,12 @@ class Connection(BaseConnection[Row]): """ Connect to a database server and return a new `Connection` instance. """ - conninfo, timeout = _conninfo_connect_timeout(conninfo, **kwargs) + params = cls._get_connection_params(conninfo, **kwargs) + conninfo = make_conninfo(**params) + rv = cls._wait_conn( cls._connect_gen(conninfo, autocommit=autocommit), - timeout, + timeout=params["connect_timeout"], ) if row_factory: rv.row_factory = row_factory @@ -602,6 +604,23 @@ class Connection(BaseConnection[Row]): if not getattr(self, "_pool", None): self.close() + @classmethod + def _get_connection_params( + cls, conninfo: str, **kwargs: Any + ) -> Dict[str, Any]: + """Adjust connection parameters before conecting.""" + params = conninfo_to_dict(conninfo, **kwargs) + + # Make sure there is an usable connect_timeout + if "connect_timeout" in params: + params["connect_timeout"] = int(params["connect_timeout"]) + else: + params["connect_timeout"] = None + + # TODO: SRV lookup (RFC 2782) + + return params + def close(self) -> None: """Close the database connection.""" if self.closed: diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 796716b48..36b6bf2bf 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -8,9 +8,8 @@ import asyncio import logging import warnings from types import TracebackType -from typing import Any, AsyncIterator, cast -from typing import Optional, Type, Union -from typing import overload, TYPE_CHECKING +from typing import Any, AsyncIterator, Dict, Optional, Type, Union +from typing import cast, overload, TYPE_CHECKING from . import errors as e from . import waiting @@ -19,7 +18,7 @@ from .abc import Params, PQGen, PQGenConn, Query, RV from .rows import Row, AsyncRowFactory, tuple_row, TupleRow from ._enums import IsolationLevel from .compat import asynccontextmanager -from .conninfo import _conninfo_connect_timeout +from .conninfo import make_conninfo, conninfo_to_dict from .connection import BaseConnection, CursorRow, Notify from .generators import notifies from .transaction import AsyncTransaction @@ -87,10 +86,12 @@ class AsyncConnection(BaseConnection[Row]): row_factory: Optional[AsyncRowFactory[Row]] = None, **kwargs: Any, ) -> "AsyncConnection[Any]": - conninfo, timeout = _conninfo_connect_timeout(conninfo, **kwargs) + params = await cls._get_connection_params(conninfo, **kwargs) + conninfo = make_conninfo(**params) + rv = await cls._wait_conn( cls._connect_gen(conninfo, autocommit=autocommit), - timeout, + timeout=params["connect_timeout"], ) if row_factory: rv.row_factory = row_factory @@ -125,6 +126,24 @@ class AsyncConnection(BaseConnection[Row]): if not getattr(self, "_pool", None): await self.close() + @classmethod + async def _get_connection_params( + cls, conninfo: str, **kwargs: Any + ) -> Dict[str, Any]: + """Adjust connection parameters before conecting.""" + params = conninfo_to_dict(conninfo, **kwargs) + + # Make sure there is an usable connect_timeout + if "connect_timeout" in params: + params["connect_timeout"] = int(params["connect_timeout"]) + else: + params["connect_timeout"] = None + + # TODO: resolve host names to hostaddr asynchronously + # TODO: SRV lookup (RFC 2782) + + return params + async def close(self) -> None: if self.closed: return diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py index c93db7908..058eaf663 100644 --- a/psycopg/psycopg/conninfo.py +++ b/psycopg/psycopg/conninfo.py @@ -5,7 +5,7 @@ Functions to manipulate conninfo strings # Copyright (C) 2020-2021 The Psycopg Team import re -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional from pathlib import Path from datetime import tzinfo @@ -49,18 +49,22 @@ def make_conninfo(conninfo: str = "", **kwargs: Any) -> str: return conninfo -def conninfo_to_dict(conninfo: str) -> Dict[str, str]: +def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> Dict[str, Any]: """ Convert the *conninfo* string into a dictionary of parameters. Raise ProgrammingError if the string is not valid. """ opts = _parse_conninfo(conninfo) - return { + rv = { opt.keyword.decode("utf8"): opt.val.decode("utf8") for opt in opts if opt.val is not None } + for k, v in kwargs.items(): + if v is not None: + rv[k] = v + return rv def _parse_conninfo(conninfo: str) -> List[pq.ConninfoOption]: @@ -95,22 +99,6 @@ def _param_escape(s: str) -> str: return s -def _conninfo_connect_timeout( - conninfo: str, **kwargs: Any -) -> Tuple[str, Optional[int]]: - """ - Build 'conninfo' by combining input value with kwargs and extract - 'connect_timeout' parameter. - """ - conninfo = make_conninfo(conninfo, **kwargs) - connect_timeout: Optional[int] - try: - connect_timeout = int(conninfo_to_dict(conninfo)["connect_timeout"]) - except KeyError: - connect_timeout = None - return conninfo, connect_timeout - - class ConnectionInfo: """Allow access to information about the connection.""" diff --git a/tests/test_connection.py b/tests/test_connection.py index c35d6d9d2..8171881fe 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -11,7 +11,7 @@ from psycopg import encodings from psycopg import Connection, Notify from psycopg.rows import tuple_row from psycopg.errors import UndefinedTable -from psycopg.conninfo import conninfo_to_dict +from psycopg.conninfo import conninfo_to_dict, make_conninfo from .utils import gc_collect from .test_cursor import my_row_factory @@ -690,3 +690,40 @@ def test_set_transaction_param_strange(conn): conn.deferrable = 0 assert conn.deferrable is False + + +conninfo_params_timeout = [ + ( + "", + {"host": "localhost", "connect_timeout": None}, + ({"host": "localhost"}, None), + ), + ( + "", + {"host": "localhost", "connect_timeout": 1}, + ({"host": "localhost", "connect_timeout": "1"}, 1), + ), + ( + "dbname=postgres", + {}, + ({"dbname": "postgres"}, None), + ), + ( + "dbname=postgres connect_timeout=2", + {}, + ({"dbname": "postgres", "connect_timeout": "2"}, 2), + ), + ( + "postgresql:///postgres?connect_timeout=2", + {"connect_timeout": 10}, + ({"dbname": "postgres", "connect_timeout": "10"}, 10), + ), +] + + +@pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout) +def test_get_connection_params(dsn, kwargs, exp): + params = Connection._get_connection_params(dsn, **kwargs) + conninfo = make_conninfo(**params) + assert conninfo_to_dict(conninfo) == exp[0] + assert params.get("connect_timeout") == exp[1] diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index 8e54f4a10..46ccdc435 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -10,11 +10,11 @@ from psycopg import encodings from psycopg import AsyncConnection, Notify from psycopg.rows import tuple_row from psycopg.errors import UndefinedTable -from psycopg.conninfo import conninfo_to_dict +from psycopg.conninfo import conninfo_to_dict, make_conninfo from .utils import gc_collect from .test_cursor import my_row_factory -from .test_connection import tx_params, tx_values_map +from .test_connection import tx_params, tx_values_map, conninfo_params_timeout pytestmark = pytest.mark.asyncio @@ -696,3 +696,11 @@ async def test_set_transaction_param_strange(aconn): await aconn.set_deferrable(0) assert aconn.deferrable is False + + +@pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout) +async def test_get_connection_params(dsn, kwargs, exp): + params = await AsyncConnection._get_connection_params(dsn, **kwargs) + conninfo = make_conninfo(**params) + assert conninfo_to_dict(conninfo) == exp[0] + assert params["connect_timeout"] == exp[1] diff --git a/tests/test_conninfo.py b/tests/test_conninfo.py index 611d72a57..c36f75a7e 100644 --- a/tests/test_conninfo.py +++ b/tests/test_conninfo.py @@ -5,12 +5,7 @@ import pytest import psycopg from psycopg import ProgrammingError -from psycopg.conninfo import ( - _conninfo_connect_timeout, - make_conninfo, - conninfo_to_dict, - ConnectionInfo, -) +from psycopg.conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo snowman = "\u2603" @@ -95,37 +90,6 @@ def test_no_munging(): assert dsnin == dsnout -@pytest.mark.parametrize( - "dsn, kwargs, exp", - [ - ( - "", - {"host": "localhost", "connect_timeout": 1}, - ({"host": "localhost", "connect_timeout": "1"}, 1), - ), - ( - "dbname=postgres", - {}, - ({"dbname": "postgres"}, None), - ), - ( - "dbname=postgres connect_timeout=2", - {}, - ({"dbname": "postgres", "connect_timeout": "2"}, 2), - ), - ( - "postgresql:///postgres?connect_timeout=2", - {"connect_timeout": 10}, - ({"dbname": "postgres", "connect_timeout": "10"}, 10), - ), - ], -) -def test__conninfo_connect_timeout(dsn, kwargs, exp): - conninfo, connect_timeout = _conninfo_connect_timeout(dsn, **kwargs) - assert conninfo_to_dict(conninfo) == exp[0] - assert connect_timeout == exp[1] - - class TestConnectionInfo: @pytest.mark.parametrize( "attr",