def datagram_received(self, data, addr):
if self.recvfrom and not self.recvfrom.done():
self.recvfrom.set_result((data, addr))
- self.recvfrom = None
def error_received(self, exc): # pragma: no cover
if self.recvfrom and not self.recvfrom.done():
async def recvfrom(self, size, timeout):
# ignore size as there's no way I know to tell protocol about it
done = _get_running_loop().create_future()
- assert self.protocol.recvfrom is None
- self.protocol.recvfrom = done
- await _maybe_wait_for(done, timeout)
- return done.result()
+ try:
+ assert self.protocol.recvfrom is None
+ self.protocol.recvfrom = done
+ await _maybe_wait_for(done, timeout)
+ return done.result()
+ finally:
+ self.protocol.recvfrom = None
async def close(self):
self.protocol.close()
self.async_run(run)
+@unittest.skipIf(not tests.util.is_internet_reachable(), "Internet not reachable")
+class AsyncioOnlyTests(unittest.TestCase):
+ connect_udp = sys.platform == "win32"
+
+ def setUp(self):
+ self.backend = dns.asyncbackend.set_default_backend("asyncio")
+
+ def async_run(self, afunc):
+ return asyncio.run(afunc())
+
+ def testUseAfterTimeout(self):
+ if self.connect_udp:
+ self.skipTest("test needs connectionless sockets")
+ # Test #843 fix.
+ async def run():
+ qname = dns.name.from_text("dns.google")
+ query = dns.message.make_query(qname, "A")
+ sock = await self.backend.make_socket(socket.AF_INET, socket.SOCK_DGRAM)
+ async with sock:
+ # First do something that will definitely timeout.
+ try:
+ response = await dns.asyncquery.udp(
+ query, "8.8.8.8", timeout=0.0001, sock=sock
+ )
+ except dns.exception.Timeout:
+ pass
+ except Exception:
+ self.assertTrue(False)
+ # Now try to reuse the socket with a reasonable timeout.
+ try:
+ response = await dns.asyncquery.udp(
+ query, "8.8.8.8", timeout=5, sock=sock
+ )
+ rrs = response.get_rrset(
+ response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+ )
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue("8.8.8.8" in seen)
+ self.assertTrue("8.8.4.4" in seen)
+ except Exception:
+ self.assertTrue(False)
+
+ self.async_run(run)
+
+
try:
import trio
import sniffio