from random import randint
from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Sequence
from typing import TYPE_CHECKING
-from functools import lru_cache
-from ipaddress import ip_address
from collections import defaultdict
try:
)
from . import errors as e
+from .conninfo import is_ip_address
if TYPE_CHECKING:
from dns.rdtypes.IN.SRV import SRV
del entries[i]
return out
-
-
-@lru_cache()
-def is_ip_address(s: str) -> bool:
- """Return True if the string represent a valid ip address."""
- try:
- ip_address(s)
- except ValueError:
- return False
- return True
from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
from .adapt import AdaptersMap
from ._enums import IsolationLevel
-from .conninfo import make_conninfo, conninfo_to_dict
+from .conninfo import make_conninfo, conninfo_to_dict, resolve_hostaddr_async
from ._pipeline import AsyncPipeline
from ._encodings import pgconn_encoding
from .connection import BaseConnection, CursorRow, Notify
else:
params["connect_timeout"] = None
+ # Resolve host addresses in non-blocking way
+ params = await resolve_hostaddr_async(params)
+
return params
async def close(self) -> None:
# Copyright (C) 2020 The Psycopg Team
+import os
import re
+import socket
+import asyncio
from typing import Any, Dict, List, Optional
from pathlib import Path
from datetime import tzinfo
+from functools import lru_cache
+from ipaddress import ip_address
from . import pq
from . import errors as e
def _get_pgconn_attr(self, name: str) -> str:
value: bytes = getattr(self.pgconn, name)
return value.decode(self.encoding)
+
+
+async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Perform async DNS lookup of the hosts and return a new params dict.
+
+ :param params: The input parameters, for instance as returned by
+ `~psycopg.conninfo.conninfo_to_dict()`.
+
+ If a ``host`` param is present but not ``hostname``, resolve the host
+ addresses dynamically.
+
+ The function may change the input ``host``, ``hostname``, ``port`` to allow
+ connecting without further DNS lookups, eventually removing hosts that are
+ not resolved, keeping the lists of hosts and ports consistent.
+
+ Raise `~psycopg.OperationalError` if connection is not possible (e.g. no
+ host resolve, inconsistent lists length).
+ """
+ hostaddr_arg = params.get("hostaddr", os.environ.get("PGHOSTADDR", ""))
+ if hostaddr_arg:
+ # Already resolved
+ return params
+
+ host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
+ if not host_arg:
+ # Nothing to resolve
+ return params
+
+ hosts_in = host_arg.split(",")
+ port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
+ ports_in = port_arg.split(",")
+ default_port = "5432"
+
+ if len(ports_in) == 1:
+ # If only one port is specified, the libpq will apply it to all
+ # the hosts, so don't mangle it.
+ default_port = ports_in.pop()
+
+ elif len(ports_in) > 1:
+ if len(ports_in) != len(hosts_in):
+ # ProgrammingError would have been more appropriate, but this is
+ # what the raise if the libpq fails connect in the same case.
+ raise e.OperationalError(
+ f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers"
+ )
+ ports_out = []
+
+ hosts_out = []
+ hostaddr_out = []
+ loop = asyncio.get_running_loop()
+ for i, host in enumerate(hosts_in):
+ if not host or host.startswith("/") or host[1:2] == ":":
+ # Local path
+ hosts_out.append(host)
+ hostaddr_out.append("")
+ if ports_in:
+ ports_out.append(ports_in[i])
+ continue
+
+ # If the host is already an ip address don't try to resolve it
+ if is_ip_address(host):
+ hosts_out.append(host)
+ hostaddr_out.append(host)
+ if ports_in:
+ ports_out.append(ports_in[i])
+ continue
+
+ try:
+ port = ports_in[i] if ports_in else default_port
+ ans = await loop.getaddrinfo(
+ host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
+ )
+ except OSError as ex:
+ last_exc = ex
+ else:
+ for item in ans:
+ hosts_out.append(host)
+ hostaddr_out.append(item[4][0])
+ if ports_in:
+ ports_out.append(ports_in[i])
+
+ # Throw an exception if no host could be resolved
+ if not hosts_out:
+ raise e.OperationalError(str(last_exc))
+
+ out = params.copy()
+ out["host"] = ",".join(hosts_out)
+ out["hostaddr"] = ",".join(hostaddr_out)
+ if ports_in:
+ out["port"] = ",".join(ports_out)
+
+ return out
+
+
+@lru_cache()
+def is_ip_address(s: str) -> bool:
+ """Return True if the string represent a valid ip address."""
+ try:
+ ip_address(s)
+ except ValueError:
+ return False
+ return True
conninfo_params_timeout = [
(
"",
- {"host": "localhost", "connect_timeout": None},
- ({"host": "localhost"}, None),
+ {"dbname": "mydb", "connect_timeout": None},
+ ({"dbname": "mydb"}, None),
),
(
"",
- {"host": "localhost", "connect_timeout": 1},
- ({"host": "localhost", "connect_timeout": "1"}, 1),
+ {"dbname": "mydb", "connect_timeout": 1},
+ ({"dbname": "mydb", "connect_timeout": "1"}, 1),
),
(
"dbname=postgres",
from .test_cursor import my_row_factory
from .test_connection import tx_params, tx_values_map, conninfo_params_timeout
from .test_adapt import make_bin_dumper, make_dumper
+from .test_conninfo import fake_resolve # noqa: F401
pytestmark = pytest.mark.asyncio
[
((), {}, ""),
(("",), {}, ""),
- (("host=foo user=bar",), {}, "host=foo user=bar"),
- (("host=foo",), {"user": "baz"}, "host=foo user=baz"),
+ (("dbname=foo user=bar",), {}, "dbname=foo user=bar"),
+ (("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"),
(
- ("host=foo port=5432",),
- {"host": "qux", "user": "joe"},
- "host=qux user=joe port=5432",
+ ("dbname=foo port=5432",),
+ {"dbname": "qux", "user": "joe"},
+ "dbname=qux user=joe port=5432",
),
- (("host=foo",), {"user": None}, "host=foo"),
+ (("dbname=foo",), {"user": None}, "dbname=foo"),
],
)
async def test_connect_args(monkeypatch, pgconn, args, kwargs, want):
async def test_cancel_closed(aconn):
await aconn.close()
aconn.cancel()
+
+
+async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve): # noqa: F811
+ got = []
+
+ def fake_connect_gen(conninfo, **kwargs):
+ got.append(conninfo)
+ 1 / 0
+
+ monkeypatch.setattr(psycopg.AsyncConnection, "_connect_gen", fake_connect_gen)
+
+ with pytest.raises(ZeroDivisionError):
+ await psycopg.AsyncConnection.connect("host=foo.com")
+
+ assert len(got) == 1
+ want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
+ assert conninfo_to_dict(got[0]) == want
+import socket
+import asyncio
import datetime as dt
import pytest
import psycopg
from psycopg import ProgrammingError
from psycopg.conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
+from psycopg.conninfo import resolve_hostaddr_async
from psycopg._encodings import pg2pyenc
snowman = "\u2603"
cur.execute("set client_encoding to EUC_TW")
with pytest.raises(psycopg.NotSupportedError):
cur.execute("select 'x'")
+
+
+@pytest.mark.parametrize(
+ "conninfo, want, env",
+ [
+ ("", "", None),
+ ("host='' user=bar", "host='' user=bar", None),
+ (
+ "host=127.0.0.1 user=bar",
+ "host=127.0.0.1 user=bar hostaddr=127.0.0.1",
+ None,
+ ),
+ (
+ "host=1.1.1.1,2.2.2.2 user=bar",
+ "host=1.1.1.1,2.2.2.2 user=bar hostaddr=1.1.1.1,2.2.2.2",
+ None,
+ ),
+ (
+ "host=1.1.1.1,2.2.2.2 port=5432",
+ "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
+ None,
+ ),
+ (
+ "port=5432",
+ "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
+ {"PGHOST": "1.1.1.1,2.2.2.2"},
+ ),
+ (
+ "host=foo.com port=5432",
+ "host=foo.com port=5432",
+ {"PGHOSTADDR": "1.2.3.4"},
+ ),
+ ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async_no_resolve(
+ monkeypatch, conninfo, want, env, fail_resolve
+):
+ if env:
+ for k, v in env.items():
+ monkeypatch.setenv(k, v)
+ params = conninfo_to_dict(conninfo)
+ params = await resolve_hostaddr_async(params)
+ assert conninfo_to_dict(want) == params
+
+
+@pytest.mark.parametrize(
+ "conninfo, want, env",
+ [
+ (
+ "host=foo.com,qux.com",
+ "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+ None,
+ ),
+ (
+ "host=foo.com,qux.com port=5433",
+ "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433",
+ None,
+ ),
+ (
+ "host=foo.com,qux.com port=5432,5433",
+ "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5432,5433",
+ None,
+ ),
+ (
+ "host=foo.com,nosuchhost.com",
+ "host=foo.com hostaddr=1.1.1.1",
+ None,
+ ),
+ (
+ "host=foo.com, port=5432,5433",
+ "host=foo.com, hostaddr=1.1.1.1, port=5432,5433",
+ None,
+ ),
+ (
+ "host=nosuchhost.com,foo.com",
+ "host=foo.com hostaddr=1.1.1.1",
+ None,
+ ),
+ (
+ "host=foo.com,qux.com",
+ "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+ {},
+ ),
+ ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async(conninfo, want, env, fake_resolve):
+ params = conninfo_to_dict(conninfo)
+ params = await resolve_hostaddr_async(params)
+ assert conninfo_to_dict(want) == params
+
+
+@pytest.mark.parametrize(
+ "conninfo, env",
+ [
+ ("host=bad1.com,bad2.com", None),
+ ("host=foo.com port=1,2", None),
+ ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None),
+ ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}),
+ ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async_bad(monkeypatch, conninfo, env, fake_resolve):
+ if env:
+ for k, v in env.items():
+ monkeypatch.setenv(k, v)
+ params = conninfo_to_dict(conninfo)
+ with pytest.raises(psycopg.Error):
+ await resolve_hostaddr_async(params)
+
+
+@pytest.fixture
+async def fake_resolve(monkeypatch):
+ fake_hosts = {
+ "localhost": "127.0.0.1",
+ "foo.com": "1.1.1.1",
+ "qux.com": "2.2.2.2",
+ }
+
+ async def fake_getaddrinfo(host, port, **kwargs):
+ try:
+ addr = fake_hosts[host]
+ except KeyError:
+ raise OSError(f"unknown test host: {host}")
+ else:
+ return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (addr, 432))]
+
+ monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fake_getaddrinfo)
+
+
+@pytest.fixture
+async def fail_resolve(monkeypatch):
+ async def fail_getaddrinfo(host, port, **kwargs):
+ pytest.fail(f"shouldn't try to resolve {host}")
+
+ monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fail_getaddrinfo)
import psycopg
from psycopg.conninfo import conninfo_to_dict
-from psycopg.rows import Row
pytestmark = [pytest.mark.dns]
await psycopg._dns.resolve_hostaddr_async(params) # type: ignore[attr-defined]
-@pytest.mark.asyncio
-async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve):
- got = []
-
- def fake_connect_gen(conninfo, **kwargs):
- got.append(conninfo)
- 1 / 0
-
- monkeypatch.setattr(psycopg.AsyncConnection, "_connect_gen", fake_connect_gen)
-
- # TODO: not enabled by default, but should be usable to make a subclass
- class AsyncDnsConnection(psycopg.AsyncConnection[Row]):
- @classmethod
- async def _get_connection_params(cls, conninfo, **kwargs):
- params = await super()._get_connection_params(conninfo, **kwargs)
- params = await (
- psycopg._dns.resolve_hostaddr_async( # type: ignore[attr-defined]
- params
- )
- )
- return params
-
- with pytest.raises(ZeroDivisionError):
- await AsyncDnsConnection.connect("host=foo.com")
-
- assert len(got) == 1
- want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
- assert conninfo_to_dict(got[0]) == want
-
-
@pytest.fixture
def fake_resolve(monkeypatch):
import_dnspython()