]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: don't block on address resolution in async connections
author: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 3 Jul 2022 00:25:19 +0000 (01:25 +0100)
committer: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 3 Jul 2022 00:25:19 +0000 (01:25 +0100)
Use the same algorithm implemented in the `_dns` module, but based on
asyncio getaddrinfo: this avoids the need for the external dns package
and works correctly with the /etc/hosts file.

There were problems of resource leaking in Python 3.6, but as psycopg
3.1 is 3.7+ only, let's go for it!

psycopg/psycopg/_dns.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/conninfo.py
tests/test_connection.py
tests/test_connection_async.py
tests/test_conninfo.py
tests/test_dns.py

index 710113843c82b474571be9e5dfcfd054799c5ebd..cab164012bb7bd59833df89d06e506596dee8050 100644 (file)
@@ -10,8 +10,6 @@ import re
 from random import randint
 from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Sequence
 from typing import TYPE_CHECKING
-from functools import lru_cache
-from ipaddress import ip_address
 from collections import defaultdict
 
 try:
@@ -24,6 +22,7 @@ except ImportError:
     )
 
 from . import errors as e
+from .conninfo import is_ip_address
 
 if TYPE_CHECKING:
     from dns.rdtypes.IN.SRV import SRV
@@ -327,13 +326,3 @@ class Rfc2782Resolver:
                 del entries[i]
 
         return out
-
-
-@lru_cache()
-def is_ip_address(s: str) -> bool:
-    """Return True if the string represent a valid ip address."""
-    try:
-        ip_address(s)
-    except ValueError:
-        return False
-    return True
index 301ffac9ca76911d1a4000845a9cde9ca719d529..8c6571d0c4aa128a651118eb844ec7ef8f9bc36e 100644 (file)
@@ -20,7 +20,7 @@ from ._tpc import Xid
 from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
-from .conninfo import make_conninfo, conninfo_to_dict
+from .conninfo import make_conninfo, conninfo_to_dict, resolve_hostaddr_async
 from ._pipeline import AsyncPipeline
 from ._encodings import pgconn_encoding
 from .connection import BaseConnection, CursorRow, Notify
@@ -181,6 +181,9 @@ class AsyncConnection(BaseConnection[Row]):
         else:
             params["connect_timeout"] = None
 
+        # Resolve host addresses in a non-blocking way
+        params = await resolve_hostaddr_async(params)
+
         return params
 
     async def close(self) -> None:
index 767952bca05d14761572dfdf9103b0513113349f..1bc85ad5cd213039f6694d2e2a4cdad60adb2ac7 100644 (file)
@@ -4,10 +4,15 @@ Functions to manipulate conninfo strings
 
 # Copyright (C) 2020 The Psycopg Team
 
+import os
 import re
+import socket
+import asyncio
 from typing import Any, Dict, List, Optional
 from pathlib import Path
 from datetime import tzinfo
+from functools import lru_cache
+from ipaddress import ip_address
 
 from . import pq
 from . import errors as e
@@ -263,3 +268,106 @@ class ConnectionInfo:
     def _get_pgconn_attr(self, name: str) -> str:
         value: bytes = getattr(self.pgconn, name)
         return value.decode(self.encoding)
+
+
+async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Perform async DNS lookup of the hosts and return a new params dict.
+
+    :param params: The input parameters, for instance as returned by
+        `~psycopg.conninfo.conninfo_to_dict()`.
+
+    If a ``host`` param is present but not ``hostaddr``, resolve the host
+    addresses dynamically.
+
+    The function may change the input ``host``, ``hostaddr``, ``port`` to allow
+    connecting without further DNS lookups, possibly removing hosts that are
+    not resolved, keeping the lists of hosts and ports consistent.
+
+    Raise `~psycopg.OperationalError` if connection is not possible (e.g. no
+    host can be resolved, or the lengths of the lists are inconsistent).
+    """
+    hostaddr_arg = params.get("hostaddr", os.environ.get("PGHOSTADDR", ""))
+    if hostaddr_arg:
+        # Already resolved
+        return params
+
+    host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
+    if not host_arg:
+        # Nothing to resolve
+        return params
+
+    hosts_in = host_arg.split(",")
+    port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
+    ports_in = port_arg.split(",")
+    default_port = "5432"
+
+    if len(ports_in) == 1:
+        # If only one port is specified, the libpq will apply it to all
+        # the hosts, so don't mangle it.
+        default_port = ports_in.pop()
+
+    elif len(ports_in) > 1:
+        if len(ports_in) != len(hosts_in):
+            # ProgrammingError would have been more appropriate, but this is
+            # what is raised if the libpq fails to connect in the same case.
+            raise e.OperationalError(
+                f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers"
+            )
+        ports_out = []
+
+    hosts_out = []
+    hostaddr_out = []
+    loop = asyncio.get_running_loop()
+    for i, host in enumerate(hosts_in):
+        if not host or host.startswith("/") or host[1:2] == ":":
+            # Local path
+            hosts_out.append(host)
+            hostaddr_out.append("")
+            if ports_in:
+                ports_out.append(ports_in[i])
+            continue
+
+        # If the host is already an ip address don't try to resolve it
+        if is_ip_address(host):
+            hosts_out.append(host)
+            hostaddr_out.append(host)
+            if ports_in:
+                ports_out.append(ports_in[i])
+            continue
+
+        try:
+            port = ports_in[i] if ports_in else default_port
+            ans = await loop.getaddrinfo(
+                host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
+            )
+        except OSError as ex:
+            last_exc = ex
+        else:
+            for item in ans:
+                hosts_out.append(host)
+                hostaddr_out.append(item[4][0])
+                if ports_in:
+                    ports_out.append(ports_in[i])
+
+    # Throw an exception if no host could be resolved
+    if not hosts_out:
+        raise e.OperationalError(str(last_exc))
+
+    out = params.copy()
+    out["host"] = ",".join(hosts_out)
+    out["hostaddr"] = ",".join(hostaddr_out)
+    if ports_in:
+        out["port"] = ",".join(ports_out)
+
+    return out
+
+
+@lru_cache()
+def is_ip_address(s: str) -> bool:
+    """Return True if the string represents a valid ip address."""
+    try:
+        ip_address(s)
+    except ValueError:
+        return False
+    return True
index 58c6556343116b789bdb5f85c841ea615f50dbd1..9922e4adf14b5ef229d12cb1fb62df9ba17ffa08 100644 (file)
@@ -706,13 +706,13 @@ def test_set_transaction_param_strange(conn):
 conninfo_params_timeout = [
     (
         "",
-        {"host": "localhost", "connect_timeout": None},
-        ({"host": "localhost"}, None),
+        {"dbname": "mydb", "connect_timeout": None},
+        ({"dbname": "mydb"}, None),
     ),
     (
         "",
-        {"host": "localhost", "connect_timeout": 1},
-        ({"host": "localhost", "connect_timeout": "1"}, 1),
+        {"dbname": "mydb", "connect_timeout": 1},
+        ({"dbname": "mydb", "connect_timeout": "1"}, 1),
     ),
     (
         "dbname=postgres",
index 93ee762f5f2d2d241b99d0c962d4cc420c598568..a0b4b9935f18944d13e7bd1705862ad1f98c0a8a 100644 (file)
@@ -12,6 +12,7 @@ from .utils import gc_collect
 from .test_cursor import my_row_factory
 from .test_connection import tx_params, tx_values_map, conninfo_params_timeout
 from .test_adapt import make_bin_dumper, make_dumper
+from .test_conninfo import fake_resolve  # noqa: F401
 
 pytestmark = pytest.mark.asyncio
 
@@ -350,14 +351,14 @@ async def test_autocommit_unknown(aconn):
     [
         ((), {}, ""),
         (("",), {}, ""),
-        (("host=foo user=bar",), {}, "host=foo user=bar"),
-        (("host=foo",), {"user": "baz"}, "host=foo user=baz"),
+        (("dbname=foo user=bar",), {}, "dbname=foo user=bar"),
+        (("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"),
         (
-            ("host=foo port=5432",),
-            {"host": "qux", "user": "joe"},
-            "host=qux user=joe port=5432",
+            ("dbname=foo port=5432",),
+            {"dbname": "qux", "user": "joe"},
+            "dbname=qux user=joe port=5432",
         ),
-        (("host=foo",), {"user": None}, "host=foo"),
+        (("dbname=foo",), {"user": None}, "dbname=foo"),
     ],
 )
 async def test_connect_args(monkeypatch, pgconn, args, kwargs, want):
@@ -721,3 +722,20 @@ async def test_connect_context_copy(dsn, aconn):
 async def test_cancel_closed(aconn):
     await aconn.close()
     aconn.cancel()
+
+
+async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve):  # noqa: F811
+    got = []
+
+    def fake_connect_gen(conninfo, **kwargs):
+        got.append(conninfo)
+        1 / 0
+
+    monkeypatch.setattr(psycopg.AsyncConnection, "_connect_gen", fake_connect_gen)
+
+    with pytest.raises(ZeroDivisionError):
+        await psycopg.AsyncConnection.connect("host=foo.com")
+
+    assert len(got) == 1
+    want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
+    assert conninfo_to_dict(got[0]) == want
index a661fc604673e7b9472007ea2610f1f8256cafe9..d637ecbdcdac18d4ee9ece72fdee76c253ed89ae 100644 (file)
@@ -1,3 +1,5 @@
+import socket
+import asyncio
 import datetime as dt
 
 import pytest
@@ -5,6 +7,7 @@ import pytest
 import psycopg
 from psycopg import ProgrammingError
 from psycopg.conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
+from psycopg.conninfo import resolve_hostaddr_async
 from psycopg._encodings import pg2pyenc
 
 snowman = "\u2603"
@@ -298,3 +301,140 @@ class TestConnectionInfo:
         cur.execute("set client_encoding to EUC_TW")
         with pytest.raises(psycopg.NotSupportedError):
             cur.execute("select 'x'")
+
+
+@pytest.mark.parametrize(
+    "conninfo, want, env",
+    [
+        ("", "", None),
+        ("host='' user=bar", "host='' user=bar", None),
+        (
+            "host=127.0.0.1 user=bar",
+            "host=127.0.0.1 user=bar hostaddr=127.0.0.1",
+            None,
+        ),
+        (
+            "host=1.1.1.1,2.2.2.2 user=bar",
+            "host=1.1.1.1,2.2.2.2 user=bar hostaddr=1.1.1.1,2.2.2.2",
+            None,
+        ),
+        (
+            "host=1.1.1.1,2.2.2.2 port=5432",
+            "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
+            None,
+        ),
+        (
+            "port=5432",
+            "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
+            {"PGHOST": "1.1.1.1,2.2.2.2"},
+        ),
+        (
+            "host=foo.com port=5432",
+            "host=foo.com port=5432",
+            {"PGHOSTADDR": "1.2.3.4"},
+        ),
+    ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async_no_resolve(
+    monkeypatch, conninfo, want, env, fail_resolve
+):
+    if env:
+        for k, v in env.items():
+            monkeypatch.setenv(k, v)
+    params = conninfo_to_dict(conninfo)
+    params = await resolve_hostaddr_async(params)
+    assert conninfo_to_dict(want) == params
+
+
+@pytest.mark.parametrize(
+    "conninfo, want, env",
+    [
+        (
+            "host=foo.com,qux.com",
+            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+            None,
+        ),
+        (
+            "host=foo.com,qux.com port=5433",
+            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433",
+            None,
+        ),
+        (
+            "host=foo.com,qux.com port=5432,5433",
+            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5432,5433",
+            None,
+        ),
+        (
+            "host=foo.com,nosuchhost.com",
+            "host=foo.com hostaddr=1.1.1.1",
+            None,
+        ),
+        (
+            "host=foo.com, port=5432,5433",
+            "host=foo.com, hostaddr=1.1.1.1, port=5432,5433",
+            None,
+        ),
+        (
+            "host=nosuchhost.com,foo.com",
+            "host=foo.com hostaddr=1.1.1.1",
+            None,
+        ),
+        (
+            "host=foo.com,qux.com",
+            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+            {},
+        ),
+    ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async(conninfo, want, env, fake_resolve):
+    params = conninfo_to_dict(conninfo)
+    params = await resolve_hostaddr_async(params)
+    assert conninfo_to_dict(want) == params
+
+
+@pytest.mark.parametrize(
+    "conninfo, env",
+    [
+        ("host=bad1.com,bad2.com", None),
+        ("host=foo.com port=1,2", None),
+        ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None),
+        ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}),
+    ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async_bad(monkeypatch, conninfo, env, fake_resolve):
+    if env:
+        for k, v in env.items():
+            monkeypatch.setenv(k, v)
+    params = conninfo_to_dict(conninfo)
+    with pytest.raises(psycopg.Error):
+        await resolve_hostaddr_async(params)
+
+
+@pytest.fixture
+async def fake_resolve(monkeypatch):
+    fake_hosts = {
+        "localhost": "127.0.0.1",
+        "foo.com": "1.1.1.1",
+        "qux.com": "2.2.2.2",
+    }
+
+    async def fake_getaddrinfo(host, port, **kwargs):
+        try:
+            addr = fake_hosts[host]
+        except KeyError:
+            raise OSError(f"unknown test host: {host}")
+        else:
+            return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (addr, 432))]
+
+    monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fake_getaddrinfo)
+
+
+@pytest.fixture
+async def fail_resolve(monkeypatch):
+    async def fail_getaddrinfo(host, port, **kwargs):
+        pytest.fail(f"shouldn't try to resolve {host}")
+
+    monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fail_getaddrinfo)
index ca8202c219f68471d843f37d1f40ab696b808713..66c7085184d9a5639128385ead61ca99210bace0 100644 (file)
@@ -2,7 +2,6 @@ import pytest
 
 import psycopg
 from psycopg.conninfo import conninfo_to_dict
-from psycopg.rows import Row
 
 pytestmark = [pytest.mark.dns]
 
@@ -121,36 +120,6 @@ async def test_resolve_hostaddr_async_bad(monkeypatch, conninfo, env, fake_resol
         await psycopg._dns.resolve_hostaddr_async(params)  # type: ignore[attr-defined]
 
 
-@pytest.mark.asyncio
-async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve):
-    got = []
-
-    def fake_connect_gen(conninfo, **kwargs):
-        got.append(conninfo)
-        1 / 0
-
-    monkeypatch.setattr(psycopg.AsyncConnection, "_connect_gen", fake_connect_gen)
-
-    # TODO: not enabled by default, but should be usable to make a subclass
-    class AsyncDnsConnection(psycopg.AsyncConnection[Row]):
-        @classmethod
-        async def _get_connection_params(cls, conninfo, **kwargs):
-            params = await super()._get_connection_params(conninfo, **kwargs)
-            params = await (
-                psycopg._dns.resolve_hostaddr_async(  # type: ignore[attr-defined]
-                    params
-                )
-            )
-            return params
-
-    with pytest.raises(ZeroDivisionError):
-        await AsyncDnsConnection.connect("host=foo.com")
-
-    assert len(got) == 1
-    want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
-    assert conninfo_to_dict(got[0]) == want
-
-
 @pytest.fixture
 def fake_resolve(monkeypatch):
     import_dnspython()