]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Allow to intermix local addresses with DNS async resolution
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 29 Aug 2021 22:01:15 +0000 (00:01 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 30 Aug 2021 03:33:29 +0000 (05:33 +0200)
psycopg/psycopg/_dns.py
tests/test_dns.py

index f4ac6ffe62f9cd510ca244a03e5317d182a5e529..c77f294d1a7a402cda6e4b4038680244b4a97a8d 100644 (file)
@@ -52,24 +52,24 @@ async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
     .. warning::
         This function currently doesn't handle the ``/etc/hosts`` file.
     """
-    host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
     hostaddr_arg = params.get("hostaddr", os.environ.get("PGHOSTADDR", ""))
-
-    if hostaddr_arg or not host_arg:
+    if hostaddr_arg:
+        # Already resolved
         return params
 
-    port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
-
     if pq.version() < 100000:
         # hostaddr not supported
         return params
 
-    if host_arg.startswith("/") or host_arg[1:2] == ":":
-        # Local path
+    host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
+    if not host_arg:
+        # Nothing to resolve
         return params
 
     hosts_in = host_arg.split(",")
+    port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
     ports_in = port_arg.split(",")
+
     if len(ports_in) == 1:
         # If only one port is specified, the libpq will apply it to all
         # the hosts, so don't mangle it.
@@ -87,6 +87,14 @@ async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
     hosts_out = []
     hostaddr_out = []
     for i, host in enumerate(hosts_in):
+        if not host or host.startswith("/") or host[1:2] == ":":
+            # Local path
+            hosts_out.append(host)
+            hostaddr_out.append("")
+            if ports_in:
+                ports_out.append(ports_in[i])
+            continue
+
         # If the host is already an ip address don't try to resolve it
         if is_ip_address(host):
             hosts_out.append(host)
index 5835fbc196d3dc126da904147bb0c152ed228a1b..d76fe60db5afaef69175aab2f3e2d51487075a9e 100644 (file)
@@ -71,6 +71,11 @@ async def test_resolve_hostaddr_async_no_resolve(
             "host=foo.com hostaddr=1.1.1.1",
             None,
         ),
+        (
+            "host=foo.com, port=5432,5433",
+            "host=foo.com, hostaddr=1.1.1.1, port=5432,5433",
+            None,
+        ),
         (
             "host=nosuchhost.com,foo.com",
             "host=foo.com hostaddr=1.1.1.1",