]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Fix hangs when QUIC connection fails [#899]. (#900)
authorBob Halley <halley@dnspython.org>
Thu, 2 Mar 2023 15:51:50 +0000 (07:51 -0800)
committerGitHub <noreply@github.com>
Thu, 2 Mar 2023 15:51:50 +0000 (07:51 -0800)
This also fixes problems with computing the wait_for() timeout for
the sync and asyncio ports, and fixes delivery of the timeout for
the sync port.

dns/quic/_asyncio.py
dns/quic/_sync.py
dns/quic/_trio.py

index bcce048114756bb03e2c6951ff890157d7cb2aea..80f244d125c8c358cf51a40a01c9037d0b0282d1 100644 (file)
@@ -17,6 +17,7 @@ from dns.quic._common import (
     AsyncQuicConnection,
     AsyncQuicManager,
     QUIC_MAX_DATAGRAM,
+    UnexpectedEOF,
 )
 
 
@@ -30,8 +31,8 @@ class AsyncioQuicStream(BaseQuicStream):
             await self._wake_up.wait()
 
     async def wait_for(self, amount, expiration):
-        timeout = self._timeout_from_expiration(expiration)
         while True:
+            timeout = self._timeout_from_expiration(expiration)
             if self._buffer.have(amount):
                 return
             self._expecting = amount
@@ -106,6 +107,11 @@ class AsyncioQuicConnection(AsyncQuicConnection):
                         self._wake_timer.notify_all()
         except Exception:
             pass
+        finally:
+            self._done = True
+            async with self._wake_timer:
+                self._wake_timer.notify_all()
+            self._handshake_complete.set()
 
     async def _wait_for_wake_timer(self):
         async with self._wake_timer:
@@ -115,7 +121,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
         await self._socket_created.wait()
         while not self._done:
             datagrams = self._connection.datagrams_to_send(time.time())
-            for (datagram, address) in datagrams:
+            for datagram, address in datagrams:
                 assert address == self._peer[0]
                 await self._socket.sendto(datagram, self._peer, None)
             (expiration, interval) = self._get_timer_values()
@@ -162,6 +168,8 @@ class AsyncioQuicConnection(AsyncQuicConnection):
 
     async def make_stream(self):
         await self._handshake_complete.wait()
+        if self._done:
+            raise UnexpectedEOF
         stream_id = self._connection.get_next_available_stream_id(False)
         stream = AsyncioQuicStream(self, stream_id)
         self._streams[stream_id] = stream
index 8cc606a9fa4795701cec55a749515c095d925f9e..bc034fa93569766b28a57f1e0cf0b9b117187c4a 100644 (file)
@@ -17,6 +17,7 @@ from dns.quic._common import (
     BaseQuicConnection,
     BaseQuicManager,
     QUIC_MAX_DATAGRAM,
+    UnexpectedEOF,
 )
 
 # Avoid circularity with dns.query
@@ -33,14 +34,15 @@ class SyncQuicStream(BaseQuicStream):
         self._lock = threading.Lock()
 
     def wait_for(self, amount, expiration):
-        timeout = self._timeout_from_expiration(expiration)
         while True:
+            timeout = self._timeout_from_expiration(expiration)
             with self._lock:
                 if self._buffer.have(amount):
                     return
                 self._expecting = amount
             with self._wake_up:
-                self._wake_up.wait(timeout)
+                if not self._wake_up.wait(timeout):
+                    raise TimeoutError
             self._expecting = 0
 
     def receive(self, timeout=None):
@@ -114,24 +116,30 @@ class SyncQuicConnection(BaseQuicConnection):
                 return
 
     def _worker(self):
-        sel = _selector_class()
-        sel.register(self._socket, selectors.EVENT_READ, self._read)
-        sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup)
-        while not self._done:
-            (expiration, interval) = self._get_timer_values(False)
-            items = sel.select(interval)
-            for (key, _) in items:
-                key.data()
+        try:
+            sel = _selector_class()
+            sel.register(self._socket, selectors.EVENT_READ, self._read)
+            sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup)
+            while not self._done:
+                (expiration, interval) = self._get_timer_values(False)
+                items = sel.select(interval)
+                for key, _ in items:
+                    key.data()
+                with self._lock:
+                    self._handle_timer(expiration)
+                    datagrams = self._connection.datagrams_to_send(time.time())
+                for datagram, _ in datagrams:
+                    try:
+                        self._socket.send(datagram)
+                    except BlockingIOError:
+                        # we let QUIC handle any lossage
+                        pass
+                self._handle_events()
+        finally:
             with self._lock:
-                self._handle_timer(expiration)
-                datagrams = self._connection.datagrams_to_send(time.time())
-            for (datagram, _) in datagrams:
-                try:
-                    self._socket.send(datagram)
-                except BlockingIOError:
-                    # we let QUIC handle any lossage
-                    pass
-            self._handle_events()
+                self._done = True
+            # Ensure anyone waiting for this gets woken up.
+            self._handshake_complete.set()
 
     def _handle_events(self):
         while True:
@@ -166,6 +174,8 @@ class SyncQuicConnection(BaseQuicConnection):
     def make_stream(self):
         self._handshake_complete.wait()
         with self._lock:
+            if self._done:
+                raise UnexpectedEOF
             stream_id = self._connection.get_next_available_stream_id(False)
             stream = SyncQuicStream(self, stream_id)
             self._streams[stream_id] = stream
index 543e3cb5325fa0ffdfa420b4b074d7115abb2af9..7f81061c970ac550dc93da7c4909918fecc22053 100644 (file)
@@ -17,6 +17,7 @@ from dns.quic._common import (
     AsyncQuicConnection,
     AsyncQuicManager,
     QUIC_MAX_DATAGRAM,
+    UnexpectedEOF,
 )
 
 
@@ -80,20 +81,26 @@ class TrioQuicConnection(AsyncQuicConnection):
         self._worker_scope = None
 
     async def _worker(self):
-        await self._socket.connect(self._peer)
-        while not self._done:
-            (expiration, interval) = self._get_timer_values(False)
-            with trio.CancelScope(
-                deadline=trio.current_time() + interval
-            ) as self._worker_scope:
-                datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
-                self._connection.receive_datagram(datagram, self._peer[0], time.time())
-            self._worker_scope = None
-            self._handle_timer(expiration)
-            datagrams = self._connection.datagrams_to_send(time.time())
-            for (datagram, _) in datagrams:
-                await self._socket.send(datagram)
-            await self._handle_events()
+        try:
+            await self._socket.connect(self._peer)
+            while not self._done:
+                (expiration, interval) = self._get_timer_values(False)
+                with trio.CancelScope(
+                    deadline=trio.current_time() + interval
+                ) as self._worker_scope:
+                    datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
+                    self._connection.receive_datagram(
+                        datagram, self._peer[0], time.time()
+                    )
+                self._worker_scope = None
+                self._handle_timer(expiration)
+                datagrams = self._connection.datagrams_to_send(time.time())
+                for datagram, _ in datagrams:
+                    await self._socket.send(datagram)
+                await self._handle_events()
+        finally:
+            self._done = True
+            self._handshake_complete.set()
 
     async def _handle_events(self):
         count = 0
@@ -132,6 +139,8 @@ class TrioQuicConnection(AsyncQuicConnection):
 
     async def make_stream(self):
         await self._handshake_complete.wait()
+        if self._done:
+            raise UnexpectedEOF
         stream_id = self._connection.get_next_available_stream_id(False)
         stream = TrioQuicStream(self, stream_id)
         self._streams[stream_id] = stream