]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Properly fix asyncio QUIC shutdown races [#1069].
authorBob Halley <halley@dnspython.org>
Mon, 25 Mar 2024 19:44:36 +0000 (12:44 -0700)
committerBob Halley <halley@dnspython.org>
Mon, 25 Mar 2024 19:44:36 +0000 (12:44 -0700)
There were two basic issues:

1) We did not wake up the sender thread to do work in some cases, and could
   sleep for a long time.
2) asyncio.wait_for() does not instantly run the function, it just schedules
   it, and our guards against lost wakeups were thus in the wrong place.

dns/quic/_asyncio.py

index 01547e8d47ef599d8ac2bba022a58c4464d02cb0..f87515dacfd2252fbb8204afdd25b8c91dff0207 100644 (file)
@@ -97,8 +97,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
         self._wake_timer = asyncio.Condition()
         self._receiver_task = None
         self._sender_task = None
-        self._send_pending = False
-        self._check_for_events = False
+        self._wake_pending = False
 
     async def _receiver(self):
         try:
@@ -119,40 +118,39 @@ class AsyncioQuicConnection(AsyncQuicConnection):
                         continue
                     self._connection.receive_datagram(datagram, address, time.time())
                     # Wake up the timer in case the sender is sleeping, as there may be
-                    # stuff to send now.  We need to set a flag as well as wake up the
-                    # timer to avoid a race where we get a datagram and generate an
-                    # event right before the sender is going to sleep.
-                    self._check_for_events = True
-                    async with self._wake_timer:
-                        self._wake_timer.notify_all()
+                    # stuff to send now.
+                    await self._wakeup()
         except Exception:
             pass
         finally:
             self._done = True
-            async with self._wake_timer:
-                self._wake_timer.notify_all()
+            await self._wakeup()
             self._handshake_complete.set()
 
+    async def _wakeup(self):
+        self._wake_pending = True
+        async with self._wake_timer:
+            self._wake_timer.notify_all()
+
     async def _wait_for_wake_timer(self):
         async with self._wake_timer:
-            await self._wake_timer.wait()
+            if not self._wake_pending:
+                await self._wake_timer.wait()
+        self._wake_pending = False
 
     async def _sender(self):
         await self._socket_created.wait()
         while not self._done:
-            self._send_pending = False
             datagrams = self._connection.datagrams_to_send(time.time())
             for datagram, address in datagrams:
                 assert address == self._peer
                 await self._socket.sendto(datagram, self._peer, None)
             (expiration, interval) = self._get_timer_values()
-            if not (self._check_for_events or self._send_pending):
-                try:
-                    await asyncio.wait_for(self._wait_for_wake_timer(), interval)
-                except Exception:
-                    pass
+            try:
+                await asyncio.wait_for(self._wait_for_wake_timer(), interval)
+            except Exception:
+                pass
             self._handle_timer(expiration)
-            self._check_for_events = False
             await self._handle_events()
 
     async def _handle_events(self):
@@ -202,9 +200,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
 
     async def write(self, stream, data, is_end=False):
         self._connection.send_stream_data(stream, data, is_end)
-        self._send_pending = True
-        async with self._wake_timer:
-            self._wake_timer.notify_all()
+        await self._wakeup()
 
     def run(self):
         if self._closed:
@@ -231,8 +227,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
             self._connection.close()
             # sender might be blocked on this, so set it
             self._socket_created.set()
-            async with self._wake_timer:
-                self._wake_timer.notify_all()
+            await self._wakeup()
             try:
                 await self._receiver_task
             except asyncio.CancelledError: