]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Keep env vars into account in async DNS resolutions
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 29 Aug 2021 16:11:20 +0000 (18:11 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 29 Aug 2021 17:33:20 +0000 (19:33 +0200)
psycopg/psycopg/_dns.py
tests/test_dns.py

index ef2028c92e98c5805d2cb7036b3a8accf01ad902..4c7798830ea442c0f94ad6419bcdf50ad0783340 100644 (file)
@@ -5,6 +5,7 @@ DNS query support
 
 # Copyright (C) 2021 The Psycopg Team
 
+import os
 from typing import Any, Dict
 from functools import lru_cache
 from ipaddress import ip_address
@@ -48,26 +49,29 @@ async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
     .. warning::
         This function doesn't handle the ``/etc/hosts`` file.
     """
-    if params.get("hostaddr") or not params.get("host"):
+    host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
+    hostaddr_arg = params.get("hostaddr", os.environ.get("PGHOSTADDR", ""))
+
+    if hostaddr_arg or not host_arg:
         return params
 
+    port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
+
     if pq.version() < 100000:
         # hostaddr not supported
         return params
 
-    host = params["host"]
-
-    if host.startswith("/") or host[1:2] == ":":
+    if host_arg.startswith("/") or host_arg[1:2] == ":":
         # Local path
         return params
 
-    hosts_in = host.split(",")
-    ports_in = str(params["port"]).split(",") if params.get("port") else []
-    if len(ports_in) <= 1:
+    hosts_in = host_arg.split(",")
+    ports_in = port_arg.split(",")
+    if len(ports_in) == 1:
         # If only one port is specified, the libpq will apply it to all
         # the hosts, so don't mangle it.
         del ports_in[:]
-    else:
+    elif len(ports_in) > 1:
         if len(ports_in) != len(hosts_in):
             # ProgrammingError would have been more appropriate, but this is
             # what the raise if the libpq fails connect in the same case.
index 7ecd6e37582a8365a72e9489752f6e04606f2883..5835fbc196d3dc126da904147bb0c152ed228a1b 100644 (file)
@@ -5,73 +5,107 @@ from psycopg.conninfo import conninfo_to_dict
 
 
 @pytest.mark.parametrize(
-    "conninfo, want",
+    "conninfo, want, env",
     [
-        ("", ""),
-        ("host='' user=bar", "host='' user=bar"),
+        ("", "", None),
+        ("host='' user=bar", "host='' user=bar", None),
         (
             "host=127.0.0.1 user=bar",
             "host=127.0.0.1 user=bar hostaddr=127.0.0.1",
+            None,
         ),
         (
             "host=1.1.1.1,2.2.2.2 user=bar",
             "host=1.1.1.1,2.2.2.2 user=bar hostaddr=1.1.1.1,2.2.2.2",
+            None,
         ),
         (
             "host=1.1.1.1,2.2.2.2 port=5432",
             "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
+            None,
+        ),
+        (
+            "port=5432",
+            "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
+            {"PGHOST": "1.1.1.1,2.2.2.2"},
+        ),
+        (
+            "host=foo.com port=5432",
+            "host=foo.com port=5432",
+            {"PGHOSTADDR": "1.2.3.4"},
         ),
     ],
 )
 @pytest.mark.asyncio
-async def test_resolve_hostaddr_async_no_resolve(conninfo, want, fail_resolve):
+async def test_resolve_hostaddr_async_no_resolve(
+    monkeypatch, conninfo, want, env, fail_resolve
+):
+    if env:
+        for k, v in env.items():
+            monkeypatch.setenv(k, v)
     params = conninfo_to_dict(conninfo)
     params = await psycopg._dns.resolve_hostaddr_async(params)
     assert conninfo_to_dict(want) == params
 
 
 @pytest.mark.parametrize(
-    "conninfo, want",
+    "conninfo, want, env",
     [
         (
             "host=foo.com,qux.com",
             "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+            None,
         ),
         (
             "host=foo.com,qux.com port=5433",
             "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433",
+            None,
         ),
         (
             "host=foo.com,qux.com port=5432,5433",
             "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5432,5433",
+            None,
         ),
         (
             "host=foo.com,nosuchhost.com",
             "host=foo.com hostaddr=1.1.1.1",
+            None,
         ),
         (
             "host=nosuchhost.com,foo.com",
             "host=foo.com hostaddr=1.1.1.1",
+            None,
+        ),
+        (
+            "host=foo.com,qux.com",
+            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+            {},
         ),
     ],
 )
 @pytest.mark.asyncio
-async def test_resolve_hostaddr_async(conninfo, want, fake_resolve):
+async def test_resolve_hostaddr_async(conninfo, want, env, fake_resolve):
     params = conninfo_to_dict(conninfo)
     params = await psycopg._dns.resolve_hostaddr_async(params)
     assert conninfo_to_dict(want) == params
 
 
 @pytest.mark.parametrize(
-    "conninfo",
+    "conninfo, env",
     [
-        "host=bad1.com,bad2.com",
-        "host=foo.com port=1,2",
-        "host=1.1.1.1,2.2.2.2 port=5432,5433,5434",
+        ("host=bad1.com,bad2.com", None),
+        ("host=foo.com port=1,2", None),
+        ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None),
+        ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}),
     ],
 )
 @pytest.mark.asyncio
-async def test_resolve_hostaddr_async_bad(conninfo, fake_resolve):
+async def test_resolve_hostaddr_async_bad(
+    monkeypatch, conninfo, env, fake_resolve
+):
+    if env:
+        for k, v in env.items():
+            monkeypatch.setenv(k, v)
     params = conninfo_to_dict(conninfo)
     with pytest.raises((TypeError, psycopg.Error)):
         await psycopg._dns.resolve_hostaddr_async(params)