From 32724db5bdfb90fb8f54a13532f2b07a8390906f Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Mon, 30 Aug 2021 00:01:15 +0200 Subject: [PATCH] Allow to intermix local addresses with DNS async resolution --- psycopg/psycopg/_dns.py | 22 +++++++++++++++------- tests/test_dns.py | 5 +++++ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/psycopg/psycopg/_dns.py b/psycopg/psycopg/_dns.py index f4ac6ffe6..c77f294d1 100644 --- a/psycopg/psycopg/_dns.py +++ b/psycopg/psycopg/_dns.py @@ -52,24 +52,24 @@ async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]: .. warning:: This function currently doesn't handle the ``/etc/hosts`` file. """ - host_arg: str = params.get("host", os.environ.get("PGHOST", "")) hostaddr_arg = params.get("hostaddr", os.environ.get("PGHOSTADDR", "")) - - if hostaddr_arg or not host_arg: + if hostaddr_arg: + # Already resolved return params - port_arg: str = str(params.get("port", os.environ.get("PGPORT", ""))) - if pq.version() < 100000: # hostaddr not supported return params - if host_arg.startswith("/") or host_arg[1:2] == ":": - # Local path + host_arg: str = params.get("host", os.environ.get("PGHOST", "")) + if not host_arg: + # Nothing to resolve return params hosts_in = host_arg.split(",") + port_arg: str = str(params.get("port", os.environ.get("PGPORT", ""))) ports_in = port_arg.split(",") + if len(ports_in) == 1: # If only one port is specified, the libpq will apply it to all # the hosts, so don't mangle it. @@ -87,6 +87,14 @@ async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]: hosts_out = [] hostaddr_out = [] for i, host in enumerate(hosts_in): + if not host or host.startswith("/") or host[1:2] == ":": + # Local path + hosts_out.append(host) + hostaddr_out.append("") + if ports_in: + ports_out.append(ports_in[i]) + continue + # If the host is already an ip address don't try to resolve it if is_ip_address(host): hosts_out.append(host) diff --git a/tests/test_dns.py b/tests/test_dns.py index 5835fbc19..d76fe60db 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -71,6 +71,11 @@ async def test_resolve_hostaddr_async_no_resolve( "host=foo.com hostaddr=1.1.1.1", None, ), + ( + "host=foo.com, port=5432,5433", + "host=foo.com, hostaddr=1.1.1.1, port=5432,5433", + None, + ), ( "host=nosuchhost.com,foo.com", "host=foo.com hostaddr=1.1.1.1", -- 2.47.3