]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fix async SRV resolution
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 30 Aug 2021 05:51:30 +0000 (07:51 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 30 Aug 2021 05:51:49 +0000 (07:51 +0200)
Also use separate fixtures to mock sync and async DNS resolution.

psycopg/psycopg/_dns.py
tests/test_dns_srv.py

index 689809415cf7e1326c4cae7083bcb0910ade9088..58e6442068c6a41a2c4dbbb2a395da72c7d90b7a 100644 (file)
@@ -259,10 +259,7 @@ class Rfc2782Resolver:
                 hp = HostPort(host=host, port=port)
             out.append(hp)
 
-        if srv_found:
-            return out
-        else:
-            return []
+        return out if srv_found else []
 
     def _resolve_srv(self, hp: HostPort) -> List[HostPort]:
         try:
@@ -273,7 +270,7 @@ class Rfc2782Resolver:
 
     async def _resolve_srv_async(self, hp: HostPort) -> List[HostPort]:
         try:
-            ans = resolver.resolve(hp.host, "SRV")
+            ans = await async_resolver.resolve(hp.host, "SRV")
         except DNSException:
             ans = ()
         return self._get_solved_entries(hp, ans)
index e642a50308a4ca2fd49903c2d44add0123f93618..a20f2a7c8cbe4dae38eea57b379869c83a2ce904 100644 (file)
@@ -57,7 +57,7 @@ def test_srv(conninfo, want, env, fake_srv, retries, monkeypatch):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("conninfo, want, env", samples_ok)
-async def test_srv_async(conninfo, want, env, fake_srv, retries, monkeypatch):
+async def test_srv_async(conninfo, want, env, afake_srv, retries, monkeypatch):
     if env:
         for k, v in env.items():
             monkeypatch.setenv(k, v)
@@ -86,7 +86,7 @@ def test_srv_bad(conninfo, env, fake_srv, monkeypatch):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("conninfo,  env", samples_bad)
-async def test_srv_bad_async(conninfo, env, fake_srv, monkeypatch):
+async def test_srv_bad_async(conninfo, env, afake_srv, monkeypatch):
     if env:
         for k, v in env.items():
             monkeypatch.setenv(k, v)
@@ -97,6 +97,21 @@ async def test_srv_bad_async(conninfo, env, fake_srv, monkeypatch):
 
 @pytest.fixture
 def fake_srv(monkeypatch):
+    f = get_fake_srv_function(monkeypatch)
+    monkeypatch.setattr(psycopg._dns.resolver, "resolve", f)
+
+
+@pytest.fixture
+def afake_srv(monkeypatch):
+    f = get_fake_srv_function(monkeypatch)
+
+    async def af(qname, rdtype):
+        return f(qname, rdtype)
+
+    monkeypatch.setattr(psycopg._dns.async_resolver, "resolve", af)
+
+
+def get_fake_srv_function(monkeypatch):
     import_dnspython()
 
     from dns.rdtypes.IN.A import A
@@ -134,8 +149,4 @@ def fake_srv(monkeypatch):
 
         return rv
 
-    async def afake_srv_(qname, rdtype):
-        return fake_srv(qname, rdtype)
-
-    monkeypatch.setattr(psycopg._dns.resolver, "resolve", fake_srv_)
-    monkeypatch.setattr(psycopg._dns.async_resolver, "resolve", afake_srv_)
+    return fake_srv_