]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Asyncio sockets should work after a timeout [#843]. 845/head
authorBob Halley <halley@dnspython.org>
Thu, 6 Oct 2022 11:33:16 +0000 (04:33 -0700)
committerBob Halley <halley@dnspython.org>
Thu, 6 Oct 2022 12:20:04 +0000 (05:20 -0700)
dns/_asyncio_backend.py
tests/test_async.py

index 50bde1dd07737f45b6576ac3527dc5e741cea501..736539bc9e5874d40560e3b0e5f97aa846d0d2a8 100644 (file)
@@ -31,7 +31,6 @@ class _DatagramProtocol:
     def datagram_received(self, data, addr):
         if self.recvfrom and not self.recvfrom.done():
             self.recvfrom.set_result((data, addr))
-            self.recvfrom = None
 
     def error_received(self, exc):  # pragma: no cover
         if self.recvfrom and not self.recvfrom.done():
@@ -68,10 +67,13 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
     async def recvfrom(self, size, timeout):
         # ignore size as there's no way I know to tell protocol about it
         done = _get_running_loop().create_future()
-        assert self.protocol.recvfrom is None
-        self.protocol.recvfrom = done
-        await _maybe_wait_for(done, timeout)
-        return done.result()
+        try:
+            assert self.protocol.recvfrom is None
+            self.protocol.recvfrom = done
+            await _maybe_wait_for(done, timeout)
+            return done.result()
+        finally:
+            self.protocol.recvfrom = None
 
     async def close(self):
         self.protocol.close()
index 2de3ca6d6fbcfa1edaa4143cb6ad73fa3aab762d..65f1ed9bfca20effce2f71906c45e68e1f68368c 100644 (file)
@@ -523,6 +523,52 @@ class AsyncTests(unittest.TestCase):
         self.async_run(run)
 
 
+@unittest.skipIf(not tests.util.is_internet_reachable(), "Internet not reachable")
+class AsyncioOnlyTests(unittest.TestCase):
+    connect_udp = sys.platform == "win32"
+
+    def setUp(self):
+        self.backend = dns.asyncbackend.set_default_backend("asyncio")
+
+    def async_run(self, afunc):
+        return asyncio.run(afunc())
+
+    def testUseAfterTimeout(self):
+        if self.connect_udp:
+            self.skipTest("test needs connectionless sockets")
+        # Test #843 fix.
+        async def run():
+            qname = dns.name.from_text("dns.google")
+            query = dns.message.make_query(qname, "A")
+            sock = await self.backend.make_socket(socket.AF_INET, socket.SOCK_DGRAM)
+            async with sock:
+                # First do something that will definitely timeout.
+                try:
+                    response = await dns.asyncquery.udp(
+                        query, "8.8.8.8", timeout=0.0001, sock=sock
+                    )
+                except dns.exception.Timeout:
+                    pass
+                except Exception:
+                    self.assertTrue(False)
+                # Now try to reuse the socket with a reasonable timeout.
+                try:
+                    response = await dns.asyncquery.udp(
+                        query, "8.8.8.8", timeout=5, sock=sock
+                    )
+                    rrs = response.get_rrset(
+                        response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+                    )
+                    self.assertTrue(rrs is not None)
+                    seen = set([rdata.address for rdata in rrs])
+                    self.assertTrue("8.8.8.8" in seen)
+                    self.assertTrue("8.8.4.4" in seen)
+                except Exception:
+                    self.assertTrue(False)
+
+        self.async_run(run)
+
+
 try:
     import trio
     import sniffio