`~zoneinfo.ZoneInfo` (ambiguous offset, see :ticket:`#652`).
- Gracefully handle EINTR on signals instead of raising `InterruptedError`,
consistently with the :pep:`475` guideline (:ticket:`#667`).
+- Fix support for connection strings with multiple hosts (:ticket:`#674`).
Current release
import logging
import threading
from types import TracebackType
-from typing import Any, Callable, cast, Dict, Generator, Generic, Iterator
+from typing import Any, Callable, cast, Generator, Generic, Iterator
from typing import List, NamedTuple, Optional, Type, TypeVar, Tuple, Union
from typing import overload, TYPE_CHECKING
from weakref import ref, ReferenceType
from ._compat import LiteralString
from .pq.misc import connection_summary
from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
+from .conninfo import conninfo_attempts, ConnDict
from ._pipeline import BasePipeline, Pipeline
from .generators import notifies, connect, execute
from ._encodings import pgconn_encoding
ConnStatus = pq.ConnStatus
TransactionStatus = pq.TransactionStatus
+ # Default timeout for a connection attempt.
+ # Arbitrary timeout: the one applied by the libpq on my computer.
+ # Your mileage won't vary.
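+ # Note: an explicit connect_timeout passed in the connection string or in
+ # the kwargs still takes precedence; this value is only used as a fallback
+ # (see _get_connection_params() below).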
+ _DEFAULT_CONNECT_TIMEOUT = 130
+
def __init__(self, pgconn: "PGconn"):
self.pgconn = pgconn
self._autocommit = False
Connect to a database server and return a new `Connection` instance.
"""
params = cls._get_connection_params(conninfo, **kwargs)
- conninfo = make_conninfo(**params)
+ timeout = int(params["connect_timeout"])
+ rv = None
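+ # Try each candidate host in turn and stop at the first successful
+ # connection; remember the last error so it can be re-raised if
+ # every attempt fails.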
+ for attempt in conninfo_attempts(params):
+ try:
+ conninfo = make_conninfo(**attempt)
+ rv = cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
+ break
+ except e._NO_TRACEBACK as ex:
+ last_ex = ex
- try:
- rv = cls._wait_conn(
- cls._connect_gen(conninfo), timeout=params["connect_timeout"]
- )
- except e._NO_TRACEBACK as ex:
- raise ex.with_traceback(None)
+ if not rv:
+ assert last_ex
+ raise last_ex.with_traceback(None)
rv._autocommit = bool(autocommit)
if row_factory:
self.close()
@classmethod
- def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> Dict[str, Any]:
+ def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict:
"""Manipulate connection parameters before connecting.
:param conninfo: Connection string as received by `~Connection.connect()`.
if "connect_timeout" in params:
params["connect_timeout"] = int(params["connect_timeout"])
else:
- params["connect_timeout"] = None
+ # The sync connect function will stop on the default socket timeout.
+ # Because in async connection mode we need to enforce the timeout
+ # ourselves, we need a finite value.
+ params["connect_timeout"] = cls._DEFAULT_CONNECT_TIMEOUT
return params
import asyncio
import logging
from types import TracebackType
-from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional
+from typing import Any, AsyncGenerator, AsyncIterator, List, Optional
from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING
from contextlib import asynccontextmanager
from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
from .adapt import AdaptersMap
from ._enums import IsolationLevel
-from .conninfo import make_conninfo, conninfo_to_dict, resolve_hostaddr_async
+from .conninfo import ConnDict, make_conninfo, conninfo_to_dict, conninfo_attempts_async
from ._pipeline import AsyncPipeline
from ._encodings import pgconn_encoding
from .connection import BaseConnection, CursorRow, Notify
)
params = await cls._get_connection_params(conninfo, **kwargs)
- conninfo = make_conninfo(**params)
+ timeout = int(params["connect_timeout"])
+ rv = None
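+ # Try each candidate host (with its hostnames resolved asynchronously)
+ # in turn and stop at the first successful connection; remember the
+ # last error so it can be re-raised if every attempt fails.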
+ async for attempt in conninfo_attempts_async(params):
+ try:
+ conninfo = make_conninfo(**attempt)
+ rv = await cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
+ break
+ except e._NO_TRACEBACK as ex:
+ last_ex = ex
- try:
- rv = await cls._wait_conn(
- cls._connect_gen(conninfo), timeout=params["connect_timeout"]
- )
- except e._NO_TRACEBACK as ex:
- raise ex.with_traceback(None)
+ if not rv:
+ assert last_ex
+ raise last_ex.with_traceback(None)
rv._autocommit = bool(autocommit)
if row_factory:
await self.close()
@classmethod
- async def _get_connection_params(
- cls, conninfo: str, **kwargs: Any
- ) -> Dict[str, Any]:
- """Manipulate connection parameters before connecting.
-
- .. versionchanged:: 3.1
- Unlike the sync counterpart, perform non-blocking address
- resolution and populate the ``hostaddr`` connection parameter,
- unless the user has provided one themselves. See
- `~psycopg._dns.resolve_hostaddr_async()` for details.
-
- """
+ async def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict:
+ """Manipulate connection parameters before connecting."""
params = conninfo_to_dict(conninfo, **kwargs)
# Make sure there is a usable connect_timeout
if "connect_timeout" in params:
params["connect_timeout"] = int(params["connect_timeout"])
else:
- params["connect_timeout"] = None
-
- # Resolve host addresses in non-blocking way
- params = await resolve_hostaddr_async(params)
+ # The sync connect function will stop on the default socket timeout.
+ # Because in async connection mode we need to enforce the timeout
+ # ourselves, we need a finite value.
+ params["connect_timeout"] = cls._DEFAULT_CONNECT_TIMEOUT
return params
return out
+def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]:
+ """Split a set of connection params on the single attempts to perforn.
+
+ A connection param can perform more than one attempt more than one ``host``
+ is provided.
+
+ Because the libpq async function doesn't honour the timeout, we need to
+ reimplement the repeated attempts.
+ """
+ for attempt in _split_attempts(_inject_defaults(params)):
+ yield attempt
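+
+# Illustrative sketch of the intended behaviour (the actual splitting is done
+# by _split_attempts(), whose body is not shown in this patch): a params dict
+# such as
+#
+#   {"host": "db1,db2", "dbname": "app"}
+#
+# should yield one attempt per host (plus whatever _inject_defaults() adds),
+# e.g. {"host": "db1", "dbname": "app"} and {"host": "db2", "dbname": "app"}.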
+
+
+async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
+ """Split a set of connection params on the single attempts to perforn.
+
+ A connection param can perform more than one attempt more than one ``host``
+ is provided.
+
+ Also perform async resolution of the hostname into hostaddr in order to
+ avoid blocking. Because a host can resolve to more than one address, this
+ can lead to yield more attempts too. Raise `OperationalError` if no host
+ could be resolved.
+
+ Because the libpq async function doesn't honour the timeout, we need to
+ reimplement the repeated attempts.
+ """
+ yielded = False
+ last_exc = None
+ for attempt in _split_attempts(_inject_defaults(params)):
+ try:
+ async for a2 in _split_attempts_and_resolve(attempt):
+ yielded = True
+ yield a2
+ except OSError as ex:
+ last_exc = ex
+
+ if not yielded:
+ assert last_exc
+ # We couldn't resolve anything
+ raise e.OperationalError(str(last_exc))
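+
+# Illustrative sketch (based on the fake resolver used by the test suite, see
+# test_resolve_hostaddr_conn further down this patch): an attempt with
+# {"host": "foo.com"} is expected to be expanded to
+# {"host": "foo.com", "hostaddr": "1.1.1.1"} before the connection is made.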
+
+
def _inject_defaults(params: ConnDict) -> ConnDict:
"""
Add defaults to a dictionary of parameters.
# Get server params
host = cdict.get("host") or os.environ.get("PGHOST")
self.server_host = host if host and not host.startswith("/") else "localhost"
- self.server_port = cdict.get("port", "5432")
+ self.server_port = cdict.get("port") or os.environ.get("PGPORT", "5432")
# Get client params
self.client_host = "localhost"
from .test_cursor import my_row_factory
from .test_adapt import make_bin_dumper, make_dumper
+DEFAULT_TIMEOUT = psycopg.Connection._DEFAULT_CONNECT_TIMEOUT
+
def test_connect(conn_cls, dsn):
conn = conn_cls.connect(dsn)
(("host=foo user=bar",), {}, "host=foo user=bar"),
(("host=foo",), {"user": "baz"}, "host=foo user=baz"),
(
- ("host=foo port=5432",),
- {"host": "qux", "user": "joe"},
- "host=qux user=joe port=5432",
+ ("dbname=foo port=5433",),
+ {"dbname": "qux", "user": "joe"},
+ "dbname=qux user=joe port=5433",
),
(("host=foo",), {"user": None}, "host=foo"),
],
)
-def test_connect_args(conn_cls, monkeypatch, pgconn, args, kwargs, want):
- the_conninfo: str
+def test_connect_args(conn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want):
+ got_conninfo: str
def fake_connect(conninfo):
- nonlocal the_conninfo
- the_conninfo = conninfo
+ nonlocal got_conninfo
+ got_conninfo = conninfo
return pgconn
yield
monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
conn = conn_cls.connect(*args, **kwargs)
- assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+ got_params = drop_default_args_from_conninfo(got_conninfo)
+ assert got_params == conninfo_to_dict(want)
conn.close()
(
"",
{"dbname": "mydb", "connect_timeout": None},
- ({"dbname": "mydb"}, None),
+ ({"dbname": "mydb"}, DEFAULT_TIMEOUT),
),
(
"",
(
"dbname=postgres",
{},
- ({"dbname": "postgres"}, None),
+ ({"dbname": "postgres"}, DEFAULT_TIMEOUT),
),
(
"dbname=postgres connect_timeout=2",
def test_get_connection_params(conn_cls, dsn, kwargs, exp):
params = conn_cls._get_connection_params(dsn, **kwargs)
conninfo = make_conninfo(**params)
- assert conninfo_to_dict(conninfo) == exp[0]
- assert params.get("connect_timeout") == exp[1]
+ assert drop_default_args_from_conninfo(conninfo) == exp[0]
+ assert params["connect_timeout"] == exp[1]
def test_connect_context(conn_cls, dsn):
def test_cancel_closed(conn):
conn.close()
conn.cancel()
+
+
+def drop_default_args_from_conninfo(conninfo):
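+ """Strip the params that the new attempts machinery injects by default.
+
+ Attempts now always carry a connect_timeout (and may carry an empty
+ host/hostaddr or the default port), so drop those before comparing
+ against the expected conninfo strings.
+ """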
+ params = conninfo_to_dict(conninfo)
+
+ def removeif(key, value):
+ if params.get(key) == value:
+ params.pop(key)
+
+ removeif("host", "")
+ removeif("hostaddr", "")
+ removeif("port", "5432")
+ removeif("connect_timeout", str(DEFAULT_TIMEOUT))
+
+ return params
from .utils import gc_collect
from .test_cursor import my_row_factory
from .test_connection import tx_params, tx_params_isolation, tx_values_map
-from .test_connection import conninfo_params_timeout
+from .test_connection import conninfo_params_timeout, drop_default_args_from_conninfo
from .test_connection import testctx # noqa: F401 # fixture
from .test_adapt import make_bin_dumper, make_dumper
from .test_conninfo import fake_resolve # noqa: F401
(("dbname=foo user=bar",), {}, "dbname=foo user=bar"),
(("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"),
(
- ("dbname=foo port=5432",),
+ ("dbname=foo port=5433",),
{"dbname": "qux", "user": "joe"},
- "dbname=qux user=joe port=5432",
+ "dbname=qux user=joe port=5433",
),
(("dbname=foo",), {"user": None}, "dbname=foo"),
],
async def test_connect_args(
aconn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want
):
- the_conninfo: str
+ got_conninfo: str
def fake_connect(conninfo):
- nonlocal the_conninfo
- the_conninfo = conninfo
+ nonlocal got_conninfo
+ got_conninfo = conninfo
return pgconn
yield
setpgenv({})
monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
conn = await aconn_cls.connect(*args, **kwargs)
- assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+ got_params = drop_default_args_from_conninfo(got_conninfo)
+ assert got_params == conninfo_to_dict(want)
await conn.close()
setpgenv({})
params = await aconn_cls._get_connection_params(dsn, **kwargs)
conninfo = make_conninfo(**params)
- assert conninfo_to_dict(conninfo) == exp[0]
+ assert drop_default_args_from_conninfo(conninfo) == exp[0]
assert params["connect_timeout"] == exp[1]
assert len(got) == 1
want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
- assert conninfo_to_dict(got[0]) == want
+ assert drop_default_args_from_conninfo(got[0]) == want
import psycopg
from psycopg.conninfo import conninfo_to_dict
-pytestmark = [pytest.mark.dns]
+from .test_conninfo import fake_resolve # noqa: F401 # fixture
+from .test_connection import drop_default_args_from_conninfo
+@pytest.mark.usefixtures("fake_resolve")
+async def test_resolve_hostaddr_conn(aconn_cls, monkeypatch):
+ got = []
+
+ def fake_connect_gen(conninfo, **kwargs):
+ got.append(conninfo)
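+ # Fail on purpose: the test only needs the conninfo built for the attempt.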
+ 1 / 0
+
+ monkeypatch.setattr(aconn_cls, "_connect_gen", fake_connect_gen)
+
+ with pytest.raises(ZeroDivisionError):
+ await aconn_cls.connect("host=foo.com")
+
+ assert len(got) == 1
+ want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
+ assert drop_default_args_from_conninfo(got[0]) == want
+
+
+@pytest.mark.dns
@pytest.mark.anyio
async def test_resolve_hostaddr_async_warning(recwarn):
import_dnspython()
params = await psycopg._dns.resolve_hostaddr_async( # type: ignore[attr-defined]
params
)
- assert conninfo_to_dict(conninfo) == params
assert "resolve_hostaddr_async" in str(recwarn.pop(DeprecationWarning).message)
import pytest
from psycopg._cmodule import _psycopg
+from psycopg.conninfo import conninfo_to_dict
+
+from .test_connection import drop_default_args_from_conninfo
@pytest.mark.parametrize(
- "args, kwargs, want_conninfo",
+ "args, kwargs, want",
[
((), {}, ""),
(("dbname=foo",), {"user": "bar"}, "dbname=foo user=bar"),
((), {"user": "foo", "dbname": None}, "user=foo"),
],
)
-def test_connect(monkeypatch, dsn, args, kwargs, want_conninfo):
+def test_connect(monkeypatch, dsn, args, kwargs, want):
# Check the main args passing from psycopg.connect to the conn generator
# Details of the params manipulation are in test_conninfo.
import psycopg.connection
monkeypatch.setattr(psycopg.connection, "connect", mock_connect)
conn = psycopg.connect(*args, **kwargs)
- assert got_conninfo == want_conninfo
+ assert drop_default_args_from_conninfo(got_conninfo) == conninfo_to_dict(want)
conn.close()
from . import dbapi20
from . import dbapi20_tpc
+from .test_connection import drop_default_args_from_conninfo
@pytest.fixture(scope="class")
(("host=foo user=bar",), {}, "host=foo user=bar"),
(("host=foo",), {"user": "baz"}, "host=foo user=baz"),
(
- ("host=foo port=5432",),
+ ("host=foo port=5433",),
{"host": "qux", "user": "joe"},
- "host=qux user=joe port=5432",
+ "host=qux user=joe port=5433",
),
(("host=foo",), {"user": None}, "host=foo"),
],
)
def test_connect_args(monkeypatch, pgconn, args, kwargs, want):
- the_conninfo: str
+ got_conninfo: str
def fake_connect(conninfo):
- nonlocal the_conninfo
- the_conninfo = conninfo
+ nonlocal got_conninfo
+ got_conninfo = conninfo
return pgconn
yield
monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
conn = psycopg.connect(*args, **kwargs)
- assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+ assert drop_default_args_from_conninfo(got_conninfo) == conninfo_to_dict(want)
conn.close()