From ba9e17df9835d70849c7498448c4654d2486a061 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Mon, 30 Aug 2021 07:51:30 +0200 Subject: [PATCH] Fix async SRV resolution Also use separate fixtures to mock sync and async DNS resolution. --- psycopg/psycopg/_dns.py | 7 ++----- tests/test_dns_srv.py | 25 ++++++++++++++++++------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/psycopg/psycopg/_dns.py b/psycopg/psycopg/_dns.py index 689809415..58e644206 100644 --- a/psycopg/psycopg/_dns.py +++ b/psycopg/psycopg/_dns.py @@ -259,10 +259,7 @@ class Rfc2782Resolver: hp = HostPort(host=host, port=port) out.append(hp) - if srv_found: - return out - else: - return [] + return out if srv_found else [] def _resolve_srv(self, hp: HostPort) -> List[HostPort]: try: @@ -273,7 +270,7 @@ class Rfc2782Resolver: async def _resolve_srv_async(self, hp: HostPort) -> List[HostPort]: try: - ans = resolver.resolve(hp.host, "SRV") + ans = await async_resolver.resolve(hp.host, "SRV") except DNSException: ans = () return self._get_solved_entries(hp, ans) diff --git a/tests/test_dns_srv.py b/tests/test_dns_srv.py index e642a5030..a20f2a7c8 100644 --- a/tests/test_dns_srv.py +++ b/tests/test_dns_srv.py @@ -57,7 +57,7 @@ def test_srv(conninfo, want, env, fake_srv, retries, monkeypatch): @pytest.mark.asyncio @pytest.mark.parametrize("conninfo, want, env", samples_ok) -async def test_srv_async(conninfo, want, env, fake_srv, retries, monkeypatch): +async def test_srv_async(conninfo, want, env, afake_srv, retries, monkeypatch): if env: for k, v in env.items(): monkeypatch.setenv(k, v) @@ -86,7 +86,7 @@ def test_srv_bad(conninfo, env, fake_srv, monkeypatch): @pytest.mark.asyncio @pytest.mark.parametrize("conninfo, env", samples_bad) -async def test_srv_bad_async(conninfo, env, fake_srv, monkeypatch): +async def test_srv_bad_async(conninfo, env, afake_srv, monkeypatch): if env: for k, v in env.items(): monkeypatch.setenv(k, v) @@ -97,6 +97,21 @@ async def test_srv_bad_async(conninfo, env, fake_srv, monkeypatch): @pytest.fixture def fake_srv(monkeypatch): + f = get_fake_srv_function(monkeypatch) + monkeypatch.setattr(psycopg._dns.resolver, "resolve", f) + + +@pytest.fixture +def afake_srv(monkeypatch): + f = get_fake_srv_function(monkeypatch) + + async def af(qname, rdtype): + return f(qname, rdtype) + + monkeypatch.setattr(psycopg._dns.async_resolver, "resolve", af) + + +def get_fake_srv_function(monkeypatch): import_dnspython() from dns.rdtypes.IN.A import A @@ -134,8 +149,4 @@ def fake_srv(monkeypatch): return rv - async def afake_srv_(qname, rdtype): - return fake_srv(qname, rdtype) - - monkeypatch.setattr(psycopg._dns.resolver, "resolve", fake_srv_) - monkeypatch.setattr(psycopg._dns.async_resolver, "resolve", afake_srv_) + return fake_srv_ -- 2.47.3