From: Daniele Varrazzo Date: Mon, 30 Aug 2021 05:51:30 +0000 (+0200) Subject: Fix async SRV resolution X-Git-Tag: 3.0~94 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=ba9e17df9835d70849c7498448c4654d2486a061;p=thirdparty%2Fpsycopg.git Fix async SRV resolution Also use separate fixtures to mock sync and async DNS resolution. --- diff --git a/psycopg/psycopg/_dns.py b/psycopg/psycopg/_dns.py index 689809415..58e644206 100644 --- a/psycopg/psycopg/_dns.py +++ b/psycopg/psycopg/_dns.py @@ -259,10 +259,7 @@ class Rfc2782Resolver: hp = HostPort(host=host, port=port) out.append(hp) - if srv_found: - return out - else: - return [] + return out if srv_found else [] def _resolve_srv(self, hp: HostPort) -> List[HostPort]: try: @@ -273,7 +270,7 @@ class Rfc2782Resolver: async def _resolve_srv_async(self, hp: HostPort) -> List[HostPort]: try: - ans = resolver.resolve(hp.host, "SRV") + ans = await async_resolver.resolve(hp.host, "SRV") except DNSException: ans = () return self._get_solved_entries(hp, ans) diff --git a/tests/test_dns_srv.py b/tests/test_dns_srv.py index e642a5030..a20f2a7c8 100644 --- a/tests/test_dns_srv.py +++ b/tests/test_dns_srv.py @@ -57,7 +57,7 @@ def test_srv(conninfo, want, env, fake_srv, retries, monkeypatch): @pytest.mark.asyncio @pytest.mark.parametrize("conninfo, want, env", samples_ok) -async def test_srv_async(conninfo, want, env, fake_srv, retries, monkeypatch): +async def test_srv_async(conninfo, want, env, afake_srv, retries, monkeypatch): if env: for k, v in env.items(): monkeypatch.setenv(k, v) @@ -86,7 +86,7 @@ def test_srv_bad(conninfo, env, fake_srv, monkeypatch): @pytest.mark.asyncio @pytest.mark.parametrize("conninfo, env", samples_bad) -async def test_srv_bad_async(conninfo, env, fake_srv, monkeypatch): +async def test_srv_bad_async(conninfo, env, afake_srv, monkeypatch): if env: for k, v in env.items(): monkeypatch.setenv(k, v) @@ -97,6 +97,21 @@ async def test_srv_bad_async(conninfo, env, fake_srv, monkeypatch): @pytest.fixture def fake_srv(monkeypatch): + f = get_fake_srv_function(monkeypatch) + monkeypatch.setattr(psycopg._dns.resolver, "resolve", f) + + +@pytest.fixture +def afake_srv(monkeypatch): + f = get_fake_srv_function(monkeypatch) + + async def af(qname, rdtype): + return f(qname, rdtype) + + monkeypatch.setattr(psycopg._dns.async_resolver, "resolve", af) + + +def get_fake_srv_function(monkeypatch): import_dnspython() from dns.rdtypes.IN.A import A @@ -134,8 +149,4 @@ def fake_srv(monkeypatch): return rv - async def afake_srv_(qname, rdtype): - return fake_srv(qname, rdtype) - - monkeypatch.setattr(psycopg._dns.resolver, "resolve", fake_srv_) - monkeypatch.setattr(psycopg._dns.async_resolver, "resolve", afake_srv_) + return fake_srv_