]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add _get_connection_params method to connections
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 21 Aug 2021 12:39:36 +0000 (14:39 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 Aug 2021 02:00:19 +0000 (04:00 +0200)
Move there the connect_timeout extraction logic, but the method is
intended to do more elaboration on the parameters before connection,
which should include asynchronous DNS lookup and possibly SRV lookup
(RFC 2782) and allows overriding in subclasses to allow experimenting.

psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/conninfo.py
tests/test_connection.py
tests/test_connection_async.py
tests/test_conninfo.py

index 0679387a23f50707f30152a0e00ff5bcdb57276f..fa4b335bbfe3549e72975278489e9f44a1769950 100644 (file)
@@ -8,7 +8,7 @@ import logging
 import warnings
 import threading
 from types import TracebackType
-from typing import Any, Callable, cast, Generic, Iterator, List
+from typing import Any, Callable, cast, Dict, Generic, Iterator, List
 from typing import NamedTuple, Optional, Type, TypeVar, Union
 from typing import overload, TYPE_CHECKING
 from weakref import ref, ReferenceType
@@ -28,7 +28,7 @@ from .rows import Row, RowFactory, tuple_row, TupleRow
 from ._enums import IsolationLevel
 from .cursor import Cursor
 from ._cmodule import _psycopg
-from .conninfo import _conninfo_connect_timeout, ConnectionInfo
+from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
 from .generators import notifies
 from ._preparing import PrepareManager
 from .transaction import Transaction
@@ -564,10 +564,12 @@ class Connection(BaseConnection[Row]):
         """
         Connect to a database server and return a new `Connection` instance.
         """
-        conninfo, timeout = _conninfo_connect_timeout(conninfo, **kwargs)
+        params = cls._get_connection_params(conninfo, **kwargs)
+        conninfo = make_conninfo(**params)
+
         rv = cls._wait_conn(
             cls._connect_gen(conninfo, autocommit=autocommit),
-            timeout,
+            timeout=params["connect_timeout"],
         )
         if row_factory:
             rv.row_factory = row_factory
@@ -602,6 +604,23 @@ class Connection(BaseConnection[Row]):
         if not getattr(self, "_pool", None):
             self.close()
 
+    @classmethod
+    def _get_connection_params(
+        cls, conninfo: str, **kwargs: Any
+    ) -> Dict[str, Any]:
+        """Adjust connection parameters before conecting."""
+        params = conninfo_to_dict(conninfo, **kwargs)
+
+        # Make sure there is an usable connect_timeout
+        if "connect_timeout" in params:
+            params["connect_timeout"] = int(params["connect_timeout"])
+        else:
+            params["connect_timeout"] = None
+
+        # TODO: SRV lookup (RFC 2782)
+
+        return params
+
     def close(self) -> None:
         """Close the database connection."""
         if self.closed:
index 796716b481325775a920e780b57eeca79b0ae436..36b6bf2bfb526da1d82a4dd5ba84c09baa553efe 100644 (file)
@@ -8,9 +8,8 @@ import asyncio
 import logging
 import warnings
 from types import TracebackType
-from typing import Any, AsyncIterator, cast
-from typing import Optional, Type, Union
-from typing import overload, TYPE_CHECKING
+from typing import Any, AsyncIterator, Dict, Optional, Type, Union
+from typing import cast, overload, TYPE_CHECKING
 
 from . import errors as e
 from . import waiting
@@ -19,7 +18,7 @@ from .abc import Params, PQGen, PQGenConn, Query, RV
 from .rows import Row, AsyncRowFactory, tuple_row, TupleRow
 from ._enums import IsolationLevel
 from .compat import asynccontextmanager
-from .conninfo import _conninfo_connect_timeout
+from .conninfo import make_conninfo, conninfo_to_dict
 from .connection import BaseConnection, CursorRow, Notify
 from .generators import notifies
 from .transaction import AsyncTransaction
@@ -87,10 +86,12 @@ class AsyncConnection(BaseConnection[Row]):
         row_factory: Optional[AsyncRowFactory[Row]] = None,
         **kwargs: Any,
     ) -> "AsyncConnection[Any]":
-        conninfo, timeout = _conninfo_connect_timeout(conninfo, **kwargs)
+        params = await cls._get_connection_params(conninfo, **kwargs)
+        conninfo = make_conninfo(**params)
+
         rv = await cls._wait_conn(
             cls._connect_gen(conninfo, autocommit=autocommit),
-            timeout,
+            timeout=params["connect_timeout"],
         )
         if row_factory:
             rv.row_factory = row_factory
@@ -125,6 +126,24 @@ class AsyncConnection(BaseConnection[Row]):
         if not getattr(self, "_pool", None):
             await self.close()
 
+    @classmethod
+    async def _get_connection_params(
+        cls, conninfo: str, **kwargs: Any
+    ) -> Dict[str, Any]:
+        """Adjust connection parameters before conecting."""
+        params = conninfo_to_dict(conninfo, **kwargs)
+
+        # Make sure there is an usable connect_timeout
+        if "connect_timeout" in params:
+            params["connect_timeout"] = int(params["connect_timeout"])
+        else:
+            params["connect_timeout"] = None
+
+        # TODO: resolve host names to hostaddr asynchronously
+        # TODO: SRV lookup (RFC 2782)
+
+        return params
+
     async def close(self) -> None:
         if self.closed:
             return
index c93db790826d88b5365e98bca59959a5daf732e1..058eaf6633e095119ddda345fa564e8458c89f2a 100644 (file)
@@ -5,7 +5,7 @@ Functions to manipulate conninfo strings
 # Copyright (C) 2020-2021 The Psycopg Team
 
 import re
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional
 from pathlib import Path
 from datetime import tzinfo
 
@@ -49,18 +49,22 @@ def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
     return conninfo
 
 
-def conninfo_to_dict(conninfo: str) -> Dict[str, str]:
+def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> Dict[str, Any]:
     """
     Convert the *conninfo* string into a dictionary of parameters.
 
     Raise ProgrammingError if the string is not valid.
     """
     opts = _parse_conninfo(conninfo)
-    return {
+    rv = {
         opt.keyword.decode("utf8"): opt.val.decode("utf8")
         for opt in opts
         if opt.val is not None
     }
+    for k, v in kwargs.items():
+        if v is not None:
+            rv[k] = v
+    return rv
 
 
 def _parse_conninfo(conninfo: str) -> List[pq.ConninfoOption]:
@@ -95,22 +99,6 @@ def _param_escape(s: str) -> str:
     return s
 
 
-def _conninfo_connect_timeout(
-    conninfo: str, **kwargs: Any
-) -> Tuple[str, Optional[int]]:
-    """
-    Build 'conninfo' by combining input value with kwargs and extract
-    'connect_timeout' parameter.
-    """
-    conninfo = make_conninfo(conninfo, **kwargs)
-    connect_timeout: Optional[int]
-    try:
-        connect_timeout = int(conninfo_to_dict(conninfo)["connect_timeout"])
-    except KeyError:
-        connect_timeout = None
-    return conninfo, connect_timeout
-
-
 class ConnectionInfo:
     """Allow access to information about the connection."""
 
index c35d6d9d2f5e53f241ee802c5f38e35499447b1a..8171881fe1f997729eace9a0eff3753664c2dd7e 100644 (file)
@@ -11,7 +11,7 @@ from psycopg import encodings
 from psycopg import Connection, Notify
 from psycopg.rows import tuple_row
 from psycopg.errors import UndefinedTable
-from psycopg.conninfo import conninfo_to_dict
+from psycopg.conninfo import conninfo_to_dict, make_conninfo
 
 from .utils import gc_collect
 from .test_cursor import my_row_factory
@@ -690,3 +690,40 @@ def test_set_transaction_param_strange(conn):
 
     conn.deferrable = 0
     assert conn.deferrable is False
+
+
+conninfo_params_timeout = [
+    (
+        "",
+        {"host": "localhost", "connect_timeout": None},
+        ({"host": "localhost"}, None),
+    ),
+    (
+        "",
+        {"host": "localhost", "connect_timeout": 1},
+        ({"host": "localhost", "connect_timeout": "1"}, 1),
+    ),
+    (
+        "dbname=postgres",
+        {},
+        ({"dbname": "postgres"}, None),
+    ),
+    (
+        "dbname=postgres connect_timeout=2",
+        {},
+        ({"dbname": "postgres", "connect_timeout": "2"}, 2),
+    ),
+    (
+        "postgresql:///postgres?connect_timeout=2",
+        {"connect_timeout": 10},
+        ({"dbname": "postgres", "connect_timeout": "10"}, 10),
+    ),
+]
+
+
+@pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout)
+def test_get_connection_params(dsn, kwargs, exp):
+    params = Connection._get_connection_params(dsn, **kwargs)
+    conninfo = make_conninfo(**params)
+    assert conninfo_to_dict(conninfo) == exp[0]
+    assert params.get("connect_timeout") == exp[1]
index 8e54f4a10d1ce21ab879444e5f43fa89b3c2edbd..46ccdc43554e2b7353cccbe31bae067a1dc18a84 100644 (file)
@@ -10,11 +10,11 @@ from psycopg import encodings
 from psycopg import AsyncConnection, Notify
 from psycopg.rows import tuple_row
 from psycopg.errors import UndefinedTable
-from psycopg.conninfo import conninfo_to_dict
+from psycopg.conninfo import conninfo_to_dict, make_conninfo
 
 from .utils import gc_collect
 from .test_cursor import my_row_factory
-from .test_connection import tx_params, tx_values_map
+from .test_connection import tx_params, tx_values_map, conninfo_params_timeout
 
 pytestmark = pytest.mark.asyncio
 
@@ -696,3 +696,11 @@ async def test_set_transaction_param_strange(aconn):
 
     await aconn.set_deferrable(0)
     assert aconn.deferrable is False
+
+
+@pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout)
+async def test_get_connection_params(dsn, kwargs, exp):
+    params = await AsyncConnection._get_connection_params(dsn, **kwargs)
+    conninfo = make_conninfo(**params)
+    assert conninfo_to_dict(conninfo) == exp[0]
+    assert params["connect_timeout"] == exp[1]
index 611d72a57f88fbc6f37ea1acf84ea2d87f9a00d2..c36f75a7e1dea3b899cd092a5523c610afde33bc 100644 (file)
@@ -5,12 +5,7 @@ import pytest
 
 import psycopg
 from psycopg import ProgrammingError
-from psycopg.conninfo import (
-    _conninfo_connect_timeout,
-    make_conninfo,
-    conninfo_to_dict,
-    ConnectionInfo,
-)
+from psycopg.conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
 
 snowman = "\u2603"
 
@@ -95,37 +90,6 @@ def test_no_munging():
     assert dsnin == dsnout
 
 
-@pytest.mark.parametrize(
-    "dsn, kwargs, exp",
-    [
-        (
-            "",
-            {"host": "localhost", "connect_timeout": 1},
-            ({"host": "localhost", "connect_timeout": "1"}, 1),
-        ),
-        (
-            "dbname=postgres",
-            {},
-            ({"dbname": "postgres"}, None),
-        ),
-        (
-            "dbname=postgres connect_timeout=2",
-            {},
-            ({"dbname": "postgres", "connect_timeout": "2"}, 2),
-        ),
-        (
-            "postgresql:///postgres?connect_timeout=2",
-            {"connect_timeout": 10},
-            ({"dbname": "postgres", "connect_timeout": "10"}, 10),
-        ),
-    ],
-)
-def test__conninfo_connect_timeout(dsn, kwargs, exp):
-    conninfo, connect_timeout = _conninfo_connect_timeout(dsn, **kwargs)
-    assert conninfo_to_dict(conninfo) == exp[0]
-    assert connect_timeout == exp[1]
-
-
 class TestConnectionInfo:
     @pytest.mark.parametrize(
         "attr",