]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Fix more races around long waits.
authorBob Halley <halley@dnspython.org>
Sat, 23 Mar 2024 12:42:20 +0000 (05:42 -0700)
committerBob Halley <halley@dnspython.org>
Sat, 23 Mar 2024 12:42:20 +0000 (05:42 -0700)
dns/quic/_asyncio.py

index 069387f4f0c94592d119e35706c236f71dac04d5..01547e8d47ef599d8ac2bba022a58c4464d02cb0 100644 (file)
@@ -97,6 +97,8 @@ class AsyncioQuicConnection(AsyncQuicConnection):
         self._wake_timer = asyncio.Condition()
         self._receiver_task = None
         self._sender_task = None
+        self._send_pending = False
+        self._check_for_events = False
 
     async def _receiver(self):
         try:
@@ -117,7 +119,10 @@ class AsyncioQuicConnection(AsyncQuicConnection):
                         continue
                     self._connection.receive_datagram(datagram, address, time.time())
                     # Wake up the timer in case the sender is sleeping, as there may be
-                    # stuff to send now.
+                    # stuff to send now.  We need to set a flag as well as wake up the
+                    # timer to avoid a race where we get a datagram and generate an
+                    # event right before the sender is going to sleep.
+                    self._check_for_events = True
                     async with self._wake_timer:
                         self._wake_timer.notify_all()
         except Exception:
@@ -135,16 +140,19 @@ class AsyncioQuicConnection(AsyncQuicConnection):
     async def _sender(self):
         await self._socket_created.wait()
         while not self._done:
+            self._send_pending = False
             datagrams = self._connection.datagrams_to_send(time.time())
             for datagram, address in datagrams:
                 assert address == self._peer
                 await self._socket.sendto(datagram, self._peer, None)
             (expiration, interval) = self._get_timer_values()
-            try:
-                await asyncio.wait_for(self._wait_for_wake_timer(), interval)
-            except Exception:
-                pass
+            if not (self._check_for_events or self._send_pending):
+                try:
+                    await asyncio.wait_for(self._wait_for_wake_timer(), interval)
+                except Exception:
+                    pass
             self._handle_timer(expiration)
+            self._check_for_events = False
             await self._handle_events()
 
     async def _handle_events(self):
@@ -194,6 +202,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
 
     async def write(self, stream, data, is_end=False):
         self._connection.send_stream_data(stream, data, is_end)
+        self._send_pending = True
         async with self._wake_timer:
             self._wake_timer.notify_all()