]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
test: test conninfo attempts functions and multiple host support in connection
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 13 Nov 2023 18:08:51 +0000 (18:08 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 13 Nov 2023 23:37:22 +0000 (00:37 +0100)
tests/test_connection.py
tests/test_connection_async.py
tests/test_conninfo.py

index ddfff5311f82f14e3b7b5033c3c85dea9b058ee6..52808af038af75afca012af502e77b39a877c3bf 100644 (file)
@@ -50,6 +50,38 @@ def test_connect_timeout(conn_cls, deaf_port):
     assert elapsed == pytest.approx(1.0, abs=0.05)
 
 
+@pytest.mark.slow
+@pytest.mark.timing
+def test_multi_hosts(conn_cls, proxy, dsn, deaf_port, monkeypatch):
+    args = conninfo_to_dict(dsn)
+    args["host"] = f"{proxy.client_host},{proxy.server_host}"
+    args["port"] = f"{deaf_port},{proxy.server_port}"
+    args.pop("hostaddr", None)
+    monkeypatch.setattr(conn_cls, "_DEFAULT_CONNECT_TIMEOUT", 2)
+    t0 = time.time()
+    with conn_cls.connect(**args) as conn:
+        elapsed = time.time() - t0
+        assert 2.0 < elapsed < 2.5
+        assert conn.info.port == int(proxy.server_port)
+        assert conn.info.host == proxy.server_host
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_multi_hosts_timeout(conn_cls, proxy, dsn, deaf_port):
+    args = conninfo_to_dict(dsn)
+    args["host"] = f"{proxy.client_host},{proxy.server_host}"
+    args["port"] = f"{deaf_port},{proxy.server_port}"
+    args.pop("hostaddr", None)
+    args["connect_timeout"] = "1"
+    t0 = time.time()
+    with conn_cls.connect(**args) as conn:
+        elapsed = time.time() - t0
+        assert 1.0 < elapsed < 1.5
+        assert conn.info.port == int(proxy.server_port)
+        assert conn.info.host == proxy.server_host
+
+
 def test_close(conn):
     assert not conn.closed
     assert not conn.broken
@@ -830,7 +862,10 @@ def test_cancel_closed(conn):
 
 
 def drop_default_args_from_conninfo(conninfo):
-    params = conninfo_to_dict(conninfo)
+    if isinstance(conninfo, str):
+        params = conninfo_to_dict(conninfo)
+    else:
+        params = conninfo.copy()
 
     def removeif(key, value):
         if params.get(key) == value:
@@ -839,6 +874,10 @@ def drop_default_args_from_conninfo(conninfo):
     removeif("host", "")
     removeif("hostaddr", "")
     removeif("port", "5432")
+    if "," in params.get("host", ""):
+        nhosts = len(params["host"].split(","))
+        removeif("port", ",".join(["5432"] * nhosts))
+        removeif("hostaddr", "," * (nhosts - 1))
     removeif("connect_timeout", str(DEFAULT_TIMEOUT))
 
     return params
index 87d8a4ee6cf9cc17998dea21d50381375aa246b2..b4c100b9be36f35f7f4dbc3d0d049cae63bbfd2e 100644 (file)
@@ -52,6 +52,38 @@ async def test_connect_timeout(aconn_cls, deaf_port):
     assert elapsed == pytest.approx(1.0, abs=0.05)
 
 
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_multi_hosts(aconn_cls, proxy, dsn, deaf_port, monkeypatch):
+    args = conninfo_to_dict(dsn)
+    args["host"] = f"{proxy.client_host},{proxy.server_host}"
+    args["port"] = f"{deaf_port},{proxy.server_port}"
+    args.pop("hostaddr", None)
+    monkeypatch.setattr(aconn_cls, "_DEFAULT_CONNECT_TIMEOUT", 2)
+    t0 = time.time()
+    async with await aconn_cls.connect(**args) as conn:
+        elapsed = time.time() - t0
+        assert 2.0 < elapsed < 2.5
+        assert conn.info.port == int(proxy.server_port)
+        assert conn.info.host == proxy.server_host
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_multi_hosts_timeout(aconn_cls, proxy, dsn, deaf_port):
+    args = conninfo_to_dict(dsn)
+    args["host"] = f"{proxy.client_host},{proxy.server_host}"
+    args["port"] = f"{deaf_port},{proxy.server_port}"
+    args.pop("hostaddr", None)
+    args["connect_timeout"] = "1"
+    t0 = time.time()
+    async with await aconn_cls.connect(**args) as conn:
+        elapsed = time.time() - t0
+        assert 1.0 < elapsed < 1.5
+        assert conn.info.port == int(proxy.server_port)
+        assert conn.info.host == proxy.server_host
+
+
 async def test_close(aconn):
     assert not aconn.closed
     assert not aconn.broken
index e037b0539343a9119710d81f76b6e4a84d5b90e4..1ebf42f79a73ba1964283cc4e48a765e9c0d219e 100644 (file)
@@ -1,16 +1,19 @@
 import socket
 import asyncio
 import datetime as dt
+from functools import reduce
 
 import pytest
 
 import psycopg
 from psycopg import ProgrammingError
 from psycopg.conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
-from psycopg.conninfo import resolve_hostaddr_async
+from psycopg.conninfo import conninfo_attempts, conninfo_attempts_async
 from psycopg._encodings import pg2pyenc
 
+from .utils import alist
 from .fix_crdb import crdb_encoding
+from .test_connection import drop_default_args_from_conninfo
 
 snowman = "\u2603"
 
@@ -316,6 +319,42 @@ class TestConnectionInfo:
         assert conn.info.vendor
 
 
+@pytest.mark.parametrize(
+    "conninfo, want, env",
+    [
+        ("", "", None),
+        ("host='' user=bar", "host='' user=bar", None),
+        (
+            "host=127.0.0.1 user=bar",
+            "host=127.0.0.1 user=bar",
+            None,
+        ),
+        (
+            "host=1.1.1.1,2.2.2.2 user=bar",
+            "host=1.1.1.1,2.2.2.2 user=bar",
+            None,
+        ),
+        (
+            "host=1.1.1.1,2.2.2.2 port=5432",
+            "host=1.1.1.1,2.2.2.2 port=5432,5432",
+            None,
+        ),
+        (
+            "host=foo.com port=5432",
+            "host=foo.com port=5432 hostaddr=1.2.3.4",
+            {"PGHOSTADDR": "1.2.3.4"},
+        ),
+    ],
+)
+@pytest.mark.anyio
+def test_conninfo_attempts(setpgenv, conninfo, want, env):
+    setpgenv(env)
+    params = conninfo_to_dict(conninfo)
+    attempts = list(conninfo_attempts(params))
+    params = drop_default_args_from_conninfo(reduce(merge_conninfos, attempts))
+    assert drop_default_args_from_conninfo(conninfo_to_dict(want)) == params
+
+
 @pytest.mark.parametrize(
     "conninfo, want, env",
     [
@@ -349,13 +388,14 @@ class TestConnectionInfo:
     ],
 )
 @pytest.mark.anyio
-async def test_resolve_hostaddr_async_no_resolve(
+async def test_conninfo_attempts_async_no_resolve(
     setpgenv, conninfo, want, env, fail_resolve
 ):
     setpgenv(env)
     params = conninfo_to_dict(conninfo)
-    params = await resolve_hostaddr_async(params)
-    assert conninfo_to_dict(want) == params
+    attempts = await alist(conninfo_attempts_async(params))
+    params = drop_default_args_from_conninfo(reduce(merge_conninfos, attempts))
+    assert drop_default_args_from_conninfo(conninfo_to_dict(want)) == params
 
 
 @pytest.mark.parametrize(
@@ -399,10 +439,11 @@ async def test_resolve_hostaddr_async_no_resolve(
     ],
 )
 @pytest.mark.anyio
-async def test_resolve_hostaddr_async(conninfo, want, env, fake_resolve):
+async def test_conninfo_attempts_async(conninfo, want, env, fake_resolve):
     params = conninfo_to_dict(conninfo)
-    params = await resolve_hostaddr_async(params)
-    assert conninfo_to_dict(want) == params
+    attempts = await alist(conninfo_attempts_async(params))
+    params = drop_default_args_from_conninfo(reduce(merge_conninfos, attempts))
+    assert drop_default_args_from_conninfo(conninfo_to_dict(want)) == params
 
 
 @pytest.mark.parametrize(
@@ -415,11 +456,27 @@ async def test_resolve_hostaddr_async(conninfo, want, env, fake_resolve):
     ],
 )
 @pytest.mark.anyio
-async def test_resolve_hostaddr_async_bad(setpgenv, conninfo, env, fake_resolve):
+async def test_conninfo_attempts_async_bad(setpgenv, conninfo, env, fake_resolve):
+    setpgenv(env)
+    params = conninfo_to_dict(conninfo)
+    with pytest.raises(psycopg.Error):
+        await alist(conninfo_attempts_async(params))
+
+
+@pytest.mark.parametrize(
+    "conninfo, env",
+    [
+        ("host=foo.com port=1,2", None),
+        ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None),
+        ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}),
+    ],
+)
+@pytest.mark.anyio
+def test_conninfo_attempts_bad(setpgenv, conninfo, env):
     setpgenv(env)
     params = conninfo_to_dict(conninfo)
     with pytest.raises(psycopg.Error):
-        await resolve_hostaddr_async(params)
+        list(conninfo_attempts(params))
 
 
 @pytest.fixture
@@ -448,3 +505,18 @@ async def fail_resolve(monkeypatch):
         pytest.fail(f"shouldn't try to resolve {host}")
 
     monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fail_getaddrinfo)
+
+
+def merge_conninfos(a1, a2):
+    """
+    merge conninfo attempts into a multi-host conninfo.
+    """
+    assert set(a1) == set(a2)
+    rv = {}
+    for k in a1:
+        if k in ("host", "hostaddr", "port"):
+            rv[k] = f"{a1[k]},{a2[k]}"
+        else:
+            assert a1[k] == a2[k]
+            rv[k] = a1[k]
+    return rv