From: Daniele Varrazzo Date: Sun, 29 Aug 2021 16:11:20 +0000 (+0200) Subject: Keep env vars into account in async DNS resolutions X-Git-Tag: 3.0.beta1~5^2~1 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=c5c8e51448a1854450832d4ce7f480c5f0ea6a91;p=thirdparty%2Fpsycopg.git Keep env vars into account in async DNS resolutions --- diff --git a/psycopg/psycopg/_dns.py b/psycopg/psycopg/_dns.py index ef2028c92..4c7798830 100644 --- a/psycopg/psycopg/_dns.py +++ b/psycopg/psycopg/_dns.py @@ -5,6 +5,7 @@ DNS query support # Copyright (C) 2021 The Psycopg Team +import os from typing import Any, Dict from functools import lru_cache from ipaddress import ip_address @@ -48,26 +49,29 @@ async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]: .. warning:: This function doesn't handle the ``/etc/hosts`` file. """ - if params.get("hostaddr") or not params.get("host"): + host_arg: str = params.get("host", os.environ.get("PGHOST", "")) + hostaddr_arg = params.get("hostaddr", os.environ.get("PGHOSTADDR", "")) + + if hostaddr_arg or not host_arg: return params + port_arg: str = str(params.get("port", os.environ.get("PGPORT", ""))) + if pq.version() < 100000: # hostaddr not supported return params - host = params["host"] - - if host.startswith("/") or host[1:2] == ":": + if host_arg.startswith("/") or host_arg[1:2] == ":": # Local path return params - hosts_in = host.split(",") - ports_in = str(params["port"]).split(",") if params.get("port") else [] - if len(ports_in) <= 1: + hosts_in = host_arg.split(",") + ports_in = port_arg.split(",") + if len(ports_in) == 1: # If only one port is specified, the libpq will apply it to all # the hosts, so don't mangle it. del ports_in[:] - else: + elif len(ports_in) > 1: if len(ports_in) != len(hosts_in): # ProgrammingError would have been more appropriate, but this is # what the raise if the libpq fails connect in the same case. diff --git a/tests/test_dns.py b/tests/test_dns.py index 7ecd6e375..5835fbc19 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -5,73 +5,107 @@ from psycopg.conninfo import conninfo_to_dict @pytest.mark.parametrize( - "conninfo, want", + "conninfo, want, env", [ - ("", ""), - ("host='' user=bar", "host='' user=bar"), + ("", "", None), + ("host='' user=bar", "host='' user=bar", None), ( "host=127.0.0.1 user=bar", "host=127.0.0.1 user=bar hostaddr=127.0.0.1", + None, ), ( "host=1.1.1.1,2.2.2.2 user=bar", "host=1.1.1.1,2.2.2.2 user=bar hostaddr=1.1.1.1,2.2.2.2", + None, ), ( "host=1.1.1.1,2.2.2.2 port=5432", "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2", + None, + ), + ( + "port=5432", + "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2", + {"PGHOST": "1.1.1.1,2.2.2.2"}, + ), + ( + "host=foo.com port=5432", + "host=foo.com port=5432", + {"PGHOSTADDR": "1.2.3.4"}, ), ], ) @pytest.mark.asyncio -async def test_resolve_hostaddr_async_no_resolve(conninfo, want, fail_resolve): +async def test_resolve_hostaddr_async_no_resolve( + monkeypatch, conninfo, want, env, fail_resolve +): + if env: + for k, v in env.items(): + monkeypatch.setenv(k, v) params = conninfo_to_dict(conninfo) params = await psycopg._dns.resolve_hostaddr_async(params) assert conninfo_to_dict(want) == params @pytest.mark.parametrize( - "conninfo, want", + "conninfo, want, env", [ ( "host=foo.com,qux.com", "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2", + None, ), ( "host=foo.com,qux.com port=5433", "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433", + None, ), ( "host=foo.com,qux.com port=5432,5433", "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5432,5433", + None, ), ( "host=foo.com,nosuchhost.com", "host=foo.com hostaddr=1.1.1.1", + None, ), ( "host=nosuchhost.com,foo.com", "host=foo.com hostaddr=1.1.1.1", + None, + ), + ( + "host=foo.com,qux.com", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2", + {}, ), ], ) @pytest.mark.asyncio -async def test_resolve_hostaddr_async(conninfo, want, fake_resolve): +async def test_resolve_hostaddr_async(conninfo, want, env, fake_resolve): params = conninfo_to_dict(conninfo) params = await psycopg._dns.resolve_hostaddr_async(params) assert conninfo_to_dict(want) == params @pytest.mark.parametrize( - "conninfo", + "conninfo, env", [ - "host=bad1.com,bad2.com", - "host=foo.com port=1,2", - "host=1.1.1.1,2.2.2.2 port=5432,5433,5434", + ("host=bad1.com,bad2.com", None), + ("host=foo.com port=1,2", None), + ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None), + ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}), ], ) @pytest.mark.asyncio -async def test_resolve_hostaddr_async_bad(conninfo, fake_resolve): +async def test_resolve_hostaddr_async_bad( + monkeypatch, conninfo, env, fake_resolve +): + if env: + for k, v in env.items(): + monkeypatch.setenv(k, v) params = conninfo_to_dict(conninfo) with pytest.raises((TypeError, psycopg.Error)): await psycopg._dns.resolve_hostaddr_async(params)