]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: drop dnspython implementation of resolve_hostaddr_async
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 3 Jul 2022 01:44:41 +0000 (02:44 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 3 Jul 2022 01:44:41 +0000 (02:44 +0100)
Expose the asyncio function in the _dns module, but mark it deprecated,
because the connection perform async resolution on its own.

psycopg/psycopg/_dns.py
tests/test_dns.py

index e7e5ddb59d30d7646541a05d2f9f0653b20b8159..ef8574ed780d0e7720b1ef7262641ae4ad34052b 100644 (file)
@@ -7,6 +7,7 @@ DNS query support
 
 import os
 import re
+import warnings
 from random import randint
 from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Sequence
 from typing import TYPE_CHECKING
@@ -22,7 +23,7 @@ except ImportError:
     )
 
 from . import errors as e
-from .conninfo import is_ip_address
+from .conninfo import resolve_hostaddr_async as resolve_hostaddr_async_
 
 if TYPE_CHECKING:
     from dns.rdtypes.IN.SRV import SRV
@@ -63,85 +64,13 @@ async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
            #LIBPQ-PARAMKEYWORDS
 
     .. warning::
-        This function currently doesn't handle the ``/etc/hosts`` file.
+        Before psycopg 3.1, this function doesn't handle the ``/etc/hosts`` file.
     """
-    hostaddr_arg = params.get("hostaddr", os.environ.get("PGHOSTADDR", ""))
-    if hostaddr_arg:
-        # Already resolved
-        return params
-
-    host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
-    if not host_arg:
-        # Nothing to resolve
-        return params
-
-    hosts_in = host_arg.split(",")
-    port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
-    ports_in = port_arg.split(",")
-
-    if len(ports_in) == 1:
-        # If only one port is specified, the libpq will apply it to all
-        # the hosts, so don't mangle it.
-        del ports_in[:]
-    elif len(ports_in) > 1:
-        if len(ports_in) != len(hosts_in):
-            # ProgrammingError would have been more appropriate, but this is
-            # what the raise if the libpq fails connect in the same case.
-            raise e.OperationalError(
-                f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers"
-            )
-        ports_out = []
-
-    hosts_out = []
-    hostaddr_out = []
-    for i, host in enumerate(hosts_in):
-        if not host or host.startswith("/") or host[1:2] == ":":
-            # Local path
-            hosts_out.append(host)
-            hostaddr_out.append("")
-            if ports_in:
-                ports_out.append(ports_in[i])
-            continue
-
-        # If the host is already an ip address don't try to resolve it
-        if is_ip_address(host):
-            hosts_out.append(host)
-            hostaddr_out.append(host)
-            if ports_in:
-                ports_out.append(ports_in[i])
-            continue
-
-        try:
-            ans = await async_resolver.resolve(host)
-        except DNSException as ex:
-            # Special case localhost: on MacOS it doesn't get resolved.
-            # I assume it is just resolved by /etc/hosts, which is not handled
-            # by dnspython.
-            if host == "localhost":
-                hosts_out.append(host)
-                hostaddr_out.append("127.0.0.1")
-                if ports_in:
-                    ports_out.append(ports_in[i])
-            else:
-                last_exc = ex
-        else:
-            for rdata in ans:
-                hosts_out.append(host)
-                hostaddr_out.append(rdata.address)
-                if ports_in:
-                    ports_out.append(ports_in[i])
-
-    # Throw an exception if no host could be resolved
-    if not hosts_out:
-        raise e.OperationalError(str(last_exc))
-
-    out = params.copy()
-    out["host"] = ",".join(hosts_out)
-    out["hostaddr"] = ",".join(hostaddr_out)
-    if ports_in:
-        out["port"] = ",".join(ports_out)
-
-    return out
+    warnings.warn(
+        "from psycopg 3.1, resolve_hostaddr_async() is not needed anymore",
+        DeprecationWarning,
+    )
+    return await resolve_hostaddr_async_(params)
 
 
 def resolve_srv(params: Dict[str, Any]) -> Dict[str, Any]:
index 66c7085184d9a5639128385ead61ca99210bace0..f50092ffbe5809dbe581a4e17c5401b3a536ece6 100644 (file)
@@ -6,160 +6,16 @@ from psycopg.conninfo import conninfo_to_dict
 pytestmark = [pytest.mark.dns]
 
 
-@pytest.mark.parametrize(
-    "conninfo, want, env",
-    [
-        ("", "", None),
-        ("host='' user=bar", "host='' user=bar", None),
-        (
-            "host=127.0.0.1 user=bar",
-            "host=127.0.0.1 user=bar hostaddr=127.0.0.1",
-            None,
-        ),
-        (
-            "host=1.1.1.1,2.2.2.2 user=bar",
-            "host=1.1.1.1,2.2.2.2 user=bar hostaddr=1.1.1.1,2.2.2.2",
-            None,
-        ),
-        (
-            "host=1.1.1.1,2.2.2.2 port=5432",
-            "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
-            None,
-        ),
-        (
-            "port=5432",
-            "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
-            {"PGHOST": "1.1.1.1,2.2.2.2"},
-        ),
-        (
-            "host=foo.com port=5432",
-            "host=foo.com port=5432",
-            {"PGHOSTADDR": "1.2.3.4"},
-        ),
-    ],
-)
 @pytest.mark.asyncio
-async def test_resolve_hostaddr_async_no_resolve(
-    monkeypatch, conninfo, want, env, fail_resolve
-):
-    if env:
-        for k, v in env.items():
-            monkeypatch.setenv(k, v)
-    params = conninfo_to_dict(conninfo)
-    params = await psycopg._dns.resolve_hostaddr_async(  # type: ignore[attr-defined]
-        params
-    )
-    assert conninfo_to_dict(want) == params
-
-
-@pytest.mark.parametrize(
-    "conninfo, want, env",
-    [
-        (
-            "host=foo.com,qux.com",
-            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
-            None,
-        ),
-        (
-            "host=foo.com,qux.com port=5433",
-            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433",
-            None,
-        ),
-        (
-            "host=foo.com,qux.com port=5432,5433",
-            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5432,5433",
-            None,
-        ),
-        (
-            "host=foo.com,nosuchhost.com",
-            "host=foo.com hostaddr=1.1.1.1",
-            None,
-        ),
-        (
-            "host=foo.com, port=5432,5433",
-            "host=foo.com, hostaddr=1.1.1.1, port=5432,5433",
-            None,
-        ),
-        (
-            "host=nosuchhost.com,foo.com",
-            "host=foo.com hostaddr=1.1.1.1",
-            None,
-        ),
-        (
-            "host=foo.com,qux.com",
-            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
-            {},
-        ),
-    ],
-)
-@pytest.mark.asyncio
-async def test_resolve_hostaddr_async(conninfo, want, env, fake_resolve):
+async def test_resolve_hostaddr_async_warning(recwarn):
+    import_dnspython()
+    conninfo = "dbname=foo"
     params = conninfo_to_dict(conninfo)
     params = await psycopg._dns.resolve_hostaddr_async(  # type: ignore[attr-defined]
         params
     )
-    assert conninfo_to_dict(want) == params
-
-
-@pytest.mark.parametrize(
-    "conninfo, env",
-    [
-        ("host=bad1.com,bad2.com", None),
-        ("host=foo.com port=1,2", None),
-        ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None),
-        ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}),
-    ],
-)
-@pytest.mark.asyncio
-async def test_resolve_hostaddr_async_bad(monkeypatch, conninfo, env, fake_resolve):
-    if env:
-        for k, v in env.items():
-            monkeypatch.setenv(k, v)
-    params = conninfo_to_dict(conninfo)
-    with pytest.raises(psycopg.Error):
-        await psycopg._dns.resolve_hostaddr_async(params)  # type: ignore[attr-defined]
-
-
-@pytest.fixture
-def fake_resolve(monkeypatch):
-    import_dnspython()
-
-    import dns.rdtypes.IN.A
-    from dns.exception import DNSException
-
-    fake_hosts = {
-        "localhost": "127.0.0.1",
-        "foo.com": "1.1.1.1",
-        "qux.com": "2.2.2.2",
-    }
-
-    async def fake_resolve_(qname):
-        try:
-            addr = fake_hosts[qname]
-        except KeyError:
-            raise DNSException(f"unknown test host: {qname}")
-        else:
-            return [dns.rdtypes.IN.A.A("IN", "A", addr)]
-
-    monkeypatch.setattr(
-        psycopg._dns.async_resolver,  # type: ignore[attr-defined]
-        "resolve",
-        fake_resolve_,
-    )
-
-
-@pytest.fixture
-def fail_resolve(monkeypatch):
-    import_dnspython()
-
-    async def fail_resolve_(qname):
-        pytest.fail(f"shouldn't try to resolve {qname}")
-
-    monkeypatch.setattr(
-        psycopg._dns.async_resolver,  # type: ignore[attr-defined]
-        "resolve",
-        fail_resolve_,
-    )
+    assert conninfo_to_dict(conninfo) == params
+    assert "resolve_hostaddr_async" in str(recwarn.pop(DeprecationWarning).message)
 
 
 def import_dnspython():