]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Fix a number of timeout bugs with QUIC [#954].
authorBob Halley <halley@dnspython.org>
Thu, 13 Jul 2023 00:05:18 +0000 (17:05 -0700)
committerBob Halley <halley@dnspython.org>
Thu, 13 Jul 2023 00:05:18 +0000 (17:05 -0700)
dns/asyncquery.py
dns/query.py
dns/quic/_asyncio.py
dns/quic/_common.py
dns/quic/_sync.py
dns/quic/_trio.py

index 4e660b535332237848b016ec7dd3c864a07e2fda..f503aace7c899ae7bc59e364a4695c19339eb0b7 100644 (file)
@@ -43,6 +43,7 @@ from dns.query import (
     _compute_times,
     _have_http2,
     _matches_destination,
+    _remaining,
     have_doh,
     ssl,
 )
@@ -736,11 +737,11 @@ async def quic(
         ) as the_manager:
             if not connection:
                 the_connection = the_manager.connect(where, port, source, source_port)
-            start = time.time()
-            stream = await the_connection.make_stream()
+            (start, expiration) = _compute_times(timeout)
+            stream = await the_connection.make_stream(timeout)
             async with stream:
                 await stream.send(wire, True)
-                wire = await stream.receive(timeout)
+                wire = await stream.receive(_remaining(expiration))
             finish = time.time()
         r = dns.message.from_wire(
             wire,
index 864c2e62670db817c7d136fb885d4fba3dc185e4..d49688ddac28376875cddd663978d6bf541a9698 100644 (file)
@@ -1186,10 +1186,10 @@ def quic(
     with manager:
         if not connection:
             the_connection = the_manager.connect(where, port, source, source_port)
-        start = time.time()
-        with the_connection.make_stream() as stream:
+        (start, expiration) = _compute_times(timeout)
+        with the_connection.make_stream(timeout) as stream:
             stream.send(wire, True)
-            wire = stream.receive(timeout)
+            wire = stream.receive(_remaining(expiration))
         finish = time.time()
     r = dns.message.from_wire(
         wire,
index b6be228f64dd6d07fc6cb2b89cb108d83e9c70a8..f01ebc331a1a656aff7ace3553e889ff19f6cf95 100644 (file)
@@ -11,6 +11,7 @@ import aioquic.quic.connection  # type: ignore
 import aioquic.quic.events  # type: ignore
 
 import dns.asyncbackend
+import dns.exception
 import dns.inet
 from dns.quic._common import (
     QUIC_MAX_DATAGRAM,
@@ -38,8 +39,8 @@ class AsyncioQuicStream(BaseQuicStream):
             self._expecting = amount
             try:
                 await asyncio.wait_for(self._wait_for_wake_up(), timeout)
-            except Exception:
-                pass
+            except TimeoutError:
+                raise dns.exception.Timeout
             self._expecting = 0
 
     async def receive(self, timeout=None):
@@ -166,8 +167,11 @@ class AsyncioQuicConnection(AsyncQuicConnection):
         self._receiver_task = asyncio.Task(self._receiver())
         self._sender_task = asyncio.Task(self._sender())
 
-    async def make_stream(self):
-        await self._handshake_complete.wait()
+    async def make_stream(self, timeout=None):
+        try:
+            await asyncio.wait_for(self._handshake_complete.wait(), timeout)
+        except TimeoutError:
+            raise dns.exception.Timeout
         if self._done:
             raise UnexpectedEOF
         stream_id = self._connection.get_next_available_stream_id(False)
index b9717be374ebed7c6262677655d88e7c2453950e..38ec103ff8c04b0fa6da4b5e67c56f4fec989b11 100644 (file)
@@ -3,7 +3,7 @@
 import socket
 import struct
 import time
-from typing import Any
+from typing import Any, Optional
 
 import aioquic.quic.configuration  # type: ignore
 import aioquic.quic.connection  # type: ignore
@@ -134,7 +134,7 @@ class BaseQuicConnection:
 
 
 class AsyncQuicConnection(BaseQuicConnection):
-    async def make_stream(self) -> Any:
+    async def make_stream(self, timeout: Optional[float] = None) -> Any:
         pass
 
 
index 5d7df5716ce5853a7076a162bfa8e01f784c3f3b..e944784dee94ac3ac39ff27a48653d534b638068 100644 (file)
@@ -11,6 +11,7 @@ import aioquic.quic.configuration  # type: ignore
 import aioquic.quic.connection  # type: ignore
 import aioquic.quic.events  # type: ignore
 
+import dns.exception
 import dns.inet
 from dns.quic._common import (
     QUIC_MAX_DATAGRAM,
@@ -42,7 +43,7 @@ class SyncQuicStream(BaseQuicStream):
                 self._expecting = amount
             with self._wake_up:
                 if not self._wake_up.wait(timeout):
-                    raise TimeoutError
+                    raise dns.exception.Timeout
             self._expecting = 0
 
     def receive(self, timeout=None):
@@ -171,8 +172,9 @@ class SyncQuicConnection(BaseQuicConnection):
         self._worker_thread = threading.Thread(target=self._worker)
         self._worker_thread.start()
 
-    def make_stream(self):
-        self._handshake_complete.wait()
+    def make_stream(self, timeout=None):
+        if not self._handshake_complete.wait(timeout):
+            raise dns.exception.Timeout
         with self._lock:
             if self._done:
                 raise UnexpectedEOF
index db73a902de6ae12676021aa51bba6e77395eee29..ee07e4f6e8808fd8a134dd33a2c8e9f9e8146fea 100644 (file)
@@ -10,6 +10,7 @@ import aioquic.quic.connection  # type: ignore
 import aioquic.quic.events  # type: ignore
 import trio
 
+import dns.exception
 import dns.inet
 from dns._asyncbackend import NullContext
 from dns.quic._common import (
@@ -45,6 +46,7 @@ class TrioQuicStream(BaseQuicStream):
             (size,) = struct.unpack("!H", self._buffer.get(2))
             await self.wait_for(size)
             return self._buffer.get(size)
+        raise dns.exception.Timeout
 
     async def send(self, datagram, is_end=False):
         data = self._encapsulate(datagram)
@@ -137,14 +139,20 @@ class TrioQuicConnection(AsyncQuicConnection):
             nursery.start_soon(self._worker)
         self._run_done.set()
 
-    async def make_stream(self):
-        await self._handshake_complete.wait()
-        if self._done:
-            raise UnexpectedEOF
-        stream_id = self._connection.get_next_available_stream_id(False)
-        stream = TrioQuicStream(self, stream_id)
-        self._streams[stream_id] = stream
-        return stream
+    async def make_stream(self, timeout=None):
+        if timeout is None:
+            context = NullContext(None)
+        else:
+            context = trio.move_on_after(timeout)
+        with context:
+            await self._handshake_complete.wait()
+            if self._done:
+                raise UnexpectedEOF
+            stream_id = self._connection.get_next_available_stream_id(False)
+            stream = TrioQuicStream(self, stream_id)
+            self._streams[stream_id] = stream
+            return stream
+        raise dns.exception.Timeout
 
     async def close(self):
         if not self._closed: