From: Bob Halley Date: Mon, 25 Mar 2024 19:44:36 +0000 (-0700) Subject: Properly fix asyncio QUIC shutdown races [#1069]. X-Git-Tag: v2.7.0rc1~59 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=7fb6e92637a22b798f47f60518d979bee13aee29;p=thirdparty%2Fdnspython.git Properly fix asyncio QUIC shutdown races [#1069]. There were two basic issues: 1) We did not wake up the sender thread to do work in some cases, and could sleep for a long time. 2) asyncio.wait_for() does not instantly run the function, it just schedules it, and our guards against lost wakeups were thus in the wrong place. --- diff --git a/dns/quic/_asyncio.py b/dns/quic/_asyncio.py index 01547e8d..f87515da 100644 --- a/dns/quic/_asyncio.py +++ b/dns/quic/_asyncio.py @@ -97,8 +97,7 @@ class AsyncioQuicConnection(AsyncQuicConnection): self._wake_timer = asyncio.Condition() self._receiver_task = None self._sender_task = None - self._send_pending = False - self._check_for_events = False + self._wake_pending = False async def _receiver(self): try: @@ -119,40 +118,39 @@ class AsyncioQuicConnection(AsyncQuicConnection): continue self._connection.receive_datagram(datagram, address, time.time()) # Wake up the timer in case the sender is sleeping, as there may be - # stuff to send now. We need to set a flag as well as wake up the - # timer to avoid a race where we get a datagram and generate an - # event right before the sender is going to sleep. - self._check_for_events = True - async with self._wake_timer: - self._wake_timer.notify_all() + # stuff to send now. + await self._wakeup() except Exception: pass finally: self._done = True - async with self._wake_timer: - self._wake_timer.notify_all() + await self._wakeup() self._handshake_complete.set() + async def _wakeup(self): + self._wake_pending = True + async with self._wake_timer: + self._wake_timer.notify_all() + async def _wait_for_wake_timer(self): async with self._wake_timer: - await self._wake_timer.wait() + if not self._wake_pending: + await self._wake_timer.wait() + self._wake_pending = False async def _sender(self): await self._socket_created.wait() while not self._done: - self._send_pending = False datagrams = self._connection.datagrams_to_send(time.time()) for datagram, address in datagrams: assert address == self._peer await self._socket.sendto(datagram, self._peer, None) (expiration, interval) = self._get_timer_values() - if not (self._check_for_events or self._send_pending): - try: - await asyncio.wait_for(self._wait_for_wake_timer(), interval) - except Exception: - pass + try: + await asyncio.wait_for(self._wait_for_wake_timer(), interval) + except Exception: + pass self._handle_timer(expiration) - self._check_for_events = False await self._handle_events() async def _handle_events(self): @@ -202,9 +200,7 @@ class AsyncioQuicConnection(AsyncQuicConnection): async def write(self, stream, data, is_end=False): self._connection.send_stream_data(stream, data, is_end) - self._send_pending = True - async with self._wake_timer: - self._wake_timer.notify_all() + await self._wakeup() def run(self): if self._closed: @@ -231,8 +227,7 @@ class AsyncioQuicConnection(AsyncQuicConnection): self._connection.close() # sender might be blocked on this, so set it self._socket_created.set() - async with self._wake_timer: - self._wake_timer.notify_all() + await self._wakeup() try: await self._receiver_task except asyncio.CancelledError: