]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(test): make resolution monkeypatching common fixtures
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 6 Jan 2024 16:10:58 +0000 (17:10 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 6 Jan 2024 19:22:02 +0000 (20:22 +0100)
tests/conftest.py
tests/fix_dns.py [new file with mode: 0644]
tests/test_connection_async.py
tests/test_conninfo.py
tests/test_dns.py

index 6647cfd9eab765669c3cbcb1ee02d71d9d199edf..db49ffdc88818ecc04d5c18b53d7c277f10565f7 100644 (file)
@@ -8,6 +8,7 @@ import pytest
 pytest_plugins = (
     "tests.fix_db",
     "tests.fix_pq",
+    "tests.fix_dns",
     "tests.fix_mypy",
     "tests.fix_faker",
     "tests.fix_proxy",
diff --git a/tests/fix_dns.py b/tests/fix_dns.py
new file mode 100644 (file)
index 0000000..4a538e5
--- /dev/null
@@ -0,0 +1,61 @@
+import asyncio
+import socket
+
+import pytest
+
+
+@pytest.fixture
+def fake_resolve(monkeypatch):
+    """
+    Fixture to return known name from name resolution.
+    """
+    fake_hosts = {
+        "localhost": ["127.0.0.1"],
+        "foo.com": ["1.1.1.1"],
+        "qux.com": ["2.2.2.2"],
+        "dup.com": ["3.3.3.3", "3.3.3.4"],
+        "alot.com": [f"4.4.4.{n}" for n in range(10, 30)],
+    }
+
+    def family(host):
+        return socket.AF_INET6 if ":" in host else socket.AF_INET
+
+    def fake_getaddrinfo(host, port, *args, **kwargs):
+        assert isinstance(port, int) or (isinstance(port, str) and port.isdigit())
+        try:
+            addrs = fake_hosts[host]
+        except KeyError:
+            raise OSError(f"unknown test host: {host}")
+        else:
+            return [
+                (family(addr), socket.SOCK_STREAM, 6, "", (addr, port))
+                for addr in addrs
+            ]
+
+    _patch_gai(monkeypatch, fake_getaddrinfo)
+
+
+@pytest.fixture
+def fail_resolve(monkeypatch):
+    """
+    Fixture to fail any name resolution.
+    """
+
+    def fail_getaddrinfo(host, port, **kwargs):
+        pytest.fail(f"shouldn't try to resolve {host}")
+
+    _patch_gai(monkeypatch, fail_getaddrinfo)
+
+
+def _patch_gai(monkeypatch, f):
+    monkeypatch.setattr(socket, "getaddrinfo", f)
+    try:
+        loop = asyncio.get_running_loop()
+    except RuntimeError:
+        pass
+    else:
+
+        async def af(*args, **kwargs):
+            return f(*args, **kwargs)
+
+        monkeypatch.setattr(loop, "getaddrinfo", af)
index 86ccfe10bf07d55b25315575e6447698c048be57..ced5db583d3e0c07557961858aa6c6d72dcc8126 100644 (file)
@@ -15,7 +15,6 @@ from .test_connection import tx_params, tx_params_isolation, tx_values_map
 from .test_connection import conninfo_params_timeout
 from .test_connection import testctx  # noqa: F401  # fixture
 from .test_adapt import make_bin_dumper, make_dumper
-from .test_conninfo import fake_resolve  # noqa: F401
 
 pytestmark = pytest.mark.anyio
 
@@ -804,7 +803,7 @@ async def test_cancel_closed(aconn):
     aconn.cancel()
 
 
-async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve):  # noqa: F811
+async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve):
     got = []
 
     def fake_connect_gen(conninfo, **kwargs):
index ae892ed3c19d5dc1c5d55e930f1c7a0ebe84f1bf..2e8c44822fc896621ab75788c4017933c7b0932c 100644 (file)
@@ -349,39 +349,3 @@ def test_timeout(setpgenv, conninfo, want, env):
     params = conninfo_to_dict(conninfo)
     timeout = timeout_from_conninfo(params)
     assert timeout == want
-
-
-@pytest.fixture
-async def fake_resolve(monkeypatch):
-    fake_hosts = {
-        "localhost": ["127.0.0.1"],
-        "foo.com": ["1.1.1.1"],
-        "qux.com": ["2.2.2.2"],
-        "dup.com": ["3.3.3.3", "3.3.3.4"],
-        "alot.com": [f"4.4.4.{n}" for n in range(10, 30)],
-    }
-
-    def family(host):
-        return socket.AF_INET6 if ":" in host else socket.AF_INET
-
-    async def fake_getaddrinfo(host, port, **kwargs):
-        assert isinstance(port, int) or (isinstance(port, str) and port.isdigit())
-        try:
-            addrs = fake_hosts[host]
-        except KeyError:
-            raise OSError(f"unknown test host: {host}")
-        else:
-            return [
-                (family(addr), socket.SOCK_STREAM, 6, "", (addr, port))
-                for addr in addrs
-            ]
-
-    monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fake_getaddrinfo)
-
-
-@pytest.fixture
-async def fail_resolve(monkeypatch):
-    async def fail_getaddrinfo(host, port, **kwargs):
-        pytest.fail(f"shouldn't try to resolve {host}")
-
-    monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fail_getaddrinfo)
index a83aaeb6694165ebd0c257c650ad7361063185f0..7ea89ad681dd2a751ac600a796d288fd03cd6199 100644 (file)
@@ -3,8 +3,6 @@ import pytest
 import psycopg
 from psycopg.conninfo import conninfo_to_dict
 
-from .test_conninfo import fake_resolve  # noqa: F401  # fixture
-
 
 @pytest.mark.usefixtures("fake_resolve")
 async def test_resolve_hostaddr_conn(aconn_cls, monkeypatch):