From 0e9186a668c5b90de798edbea3fdb355a85f0e6c Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Thu, 6 Oct 2022 04:33:16 -0700 Subject: [PATCH] Asyncio sockets should work after a timeout [#843]. --- dns/_asyncio_backend.py | 12 ++++++----- tests/test_async.py | 46 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py index 50bde1dd..736539bc 100644 --- a/dns/_asyncio_backend.py +++ b/dns/_asyncio_backend.py @@ -31,7 +31,6 @@ class _DatagramProtocol: def datagram_received(self, data, addr): if self.recvfrom and not self.recvfrom.done(): self.recvfrom.set_result((data, addr)) - self.recvfrom = None def error_received(self, exc): # pragma: no cover if self.recvfrom and not self.recvfrom.done(): @@ -68,10 +67,13 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): async def recvfrom(self, size, timeout): # ignore size as there's no way I know to tell protocol about it done = _get_running_loop().create_future() - assert self.protocol.recvfrom is None - self.protocol.recvfrom = done - await _maybe_wait_for(done, timeout) - return done.result() + try: + assert self.protocol.recvfrom is None + self.protocol.recvfrom = done + await _maybe_wait_for(done, timeout) + return done.result() + finally: + self.protocol.recvfrom = None async def close(self): self.protocol.close() diff --git a/tests/test_async.py b/tests/test_async.py index 2de3ca6d..65f1ed9b 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -523,6 +523,52 @@ class AsyncTests(unittest.TestCase): self.async_run(run) +@unittest.skipIf(not tests.util.is_internet_reachable(), "Internet not reachable") +class AsyncioOnlyTests(unittest.TestCase): + connect_udp = sys.platform == "win32" + + def setUp(self): + self.backend = dns.asyncbackend.set_default_backend("asyncio") + + def async_run(self, afunc): + return asyncio.run(afunc()) + + def testUseAfterTimeout(self): + if self.connect_udp: + self.skipTest("test needs connectionless sockets") + # Test #843 fix. + async def run(): + qname = dns.name.from_text("dns.google") + query = dns.message.make_query(qname, "A") + sock = await self.backend.make_socket(socket.AF_INET, socket.SOCK_DGRAM) + async with sock: + # First do something that will definitely timeout. + try: + response = await dns.asyncquery.udp( + query, "8.8.8.8", timeout=0.0001, sock=sock + ) + except dns.exception.Timeout: + pass + except Exception: + self.assertTrue(False) + # Now try to reuse the socket with a reasonable timeout. + try: + response = await dns.asyncquery.udp( + query, "8.8.8.8", timeout=5, sock=sock + ) + rrs = response.get_rrset( + response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A + ) + self.assertTrue(rrs is not None) + seen = set([rdata.address for rdata in rrs]) + self.assertTrue("8.8.8.8" in seen) + self.assertTrue("8.8.4.4" in seen) + except Exception: + self.assertTrue(False) + + self.async_run(run) + + try: import trio import sniffio -- 2.47.3