From 60253ac495078b1fdb30515e9d99b0bed017078c Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Wed, 12 Jul 2023 17:05:18 -0700 Subject: [PATCH] Fix a number of timeout bugs with QUIC [#954]. --- dns/asyncquery.py | 7 ++++--- dns/query.py | 6 +++--- dns/quic/_asyncio.py | 12 ++++++++---- dns/quic/_common.py | 4 ++-- dns/quic/_sync.py | 8 +++++--- dns/quic/_trio.py | 24 ++++++++++++++++-------- 6 files changed, 38 insertions(+), 23 deletions(-) diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 4e660b53..f503aace 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -43,6 +43,7 @@ from dns.query import ( _compute_times, _have_http2, _matches_destination, + _remaining, have_doh, ssl, ) @@ -736,11 +737,11 @@ async def quic( ) as the_manager: if not connection: the_connection = the_manager.connect(where, port, source, source_port) - start = time.time() - stream = await the_connection.make_stream() + (start, expiration) = _compute_times(timeout) + stream = await the_connection.make_stream(timeout) async with stream: await stream.send(wire, True) - wire = await stream.receive(timeout) + wire = await stream.receive(_remaining(expiration)) finish = time.time() r = dns.message.from_wire( wire, diff --git a/dns/query.py b/dns/query.py index 864c2e62..d49688dd 100644 --- a/dns/query.py +++ b/dns/query.py @@ -1186,10 +1186,10 @@ def quic( with manager: if not connection: the_connection = the_manager.connect(where, port, source, source_port) - start = time.time() - with the_connection.make_stream() as stream: + (start, expiration) = _compute_times(timeout) + with the_connection.make_stream(timeout) as stream: stream.send(wire, True) - wire = stream.receive(timeout) + wire = stream.receive(_remaining(expiration)) finish = time.time() r = dns.message.from_wire( wire, diff --git a/dns/quic/_asyncio.py b/dns/quic/_asyncio.py index b6be228f..f01ebc33 100644 --- a/dns/quic/_asyncio.py +++ b/dns/quic/_asyncio.py @@ -11,6 +11,7 @@ import aioquic.quic.connection # type: ignore import aioquic.quic.events # type: ignore import dns.asyncbackend +import dns.exception import dns.inet from dns.quic._common import ( QUIC_MAX_DATAGRAM, @@ -38,8 +39,8 @@ class AsyncioQuicStream(BaseQuicStream): self._expecting = amount try: await asyncio.wait_for(self._wait_for_wake_up(), timeout) - except Exception: - pass + except TimeoutError: + raise dns.exception.Timeout self._expecting = 0 async def receive(self, timeout=None): @@ -166,8 +167,11 @@ class AsyncioQuicConnection(AsyncQuicConnection): self._receiver_task = asyncio.Task(self._receiver()) self._sender_task = asyncio.Task(self._sender()) - async def make_stream(self): - await self._handshake_complete.wait() + async def make_stream(self, timeout=None): + try: + await asyncio.wait_for(self._handshake_complete.wait(), timeout) + except TimeoutError: + raise dns.exception.Timeout if self._done: raise UnexpectedEOF stream_id = self._connection.get_next_available_stream_id(False) diff --git a/dns/quic/_common.py b/dns/quic/_common.py index b9717be3..38ec103f 100644 --- a/dns/quic/_common.py +++ b/dns/quic/_common.py @@ -3,7 +3,7 @@ import socket import struct import time -from typing import Any +from typing import Any, Optional import aioquic.quic.configuration # type: ignore import aioquic.quic.connection # type: ignore @@ -134,7 +134,7 @@ class BaseQuicConnection: class AsyncQuicConnection(BaseQuicConnection): - async def make_stream(self) -> Any: + async def make_stream(self, timeout: Optional[float] = None) -> Any: pass diff --git a/dns/quic/_sync.py b/dns/quic/_sync.py index 5d7df571..e944784d 100644 --- a/dns/quic/_sync.py +++ b/dns/quic/_sync.py @@ -11,6 +11,7 @@ import aioquic.quic.configuration # type: ignore import aioquic.quic.connection # type: ignore import aioquic.quic.events # type: ignore +import dns.exception import dns.inet from dns.quic._common import ( QUIC_MAX_DATAGRAM, @@ -42,7 +43,7 @@ class SyncQuicStream(BaseQuicStream): self._expecting = amount with self._wake_up: if not self._wake_up.wait(timeout): - raise TimeoutError + raise dns.exception.Timeout self._expecting = 0 def receive(self, timeout=None): @@ -171,8 +172,9 @@ class SyncQuicConnection(BaseQuicConnection): self._worker_thread = threading.Thread(target=self._worker) self._worker_thread.start() - def make_stream(self): - self._handshake_complete.wait() + def make_stream(self, timeout=None): + if not self._handshake_complete.wait(timeout): + raise dns.exception.Timeout with self._lock: if self._done: raise UnexpectedEOF diff --git a/dns/quic/_trio.py b/dns/quic/_trio.py index db73a902..ee07e4f6 100644 --- a/dns/quic/_trio.py +++ b/dns/quic/_trio.py @@ -10,6 +10,7 @@ import aioquic.quic.connection # type: ignore import aioquic.quic.events # type: ignore import trio +import dns.exception import dns.inet from dns._asyncbackend import NullContext from dns.quic._common import ( @@ -45,6 +46,7 @@ class TrioQuicStream(BaseQuicStream): (size,) = struct.unpack("!H", self._buffer.get(2)) await self.wait_for(size) return self._buffer.get(size) + raise dns.exception.Timeout async def send(self, datagram, is_end=False): data = self._encapsulate(datagram) @@ -137,14 +139,20 @@ class TrioQuicConnection(AsyncQuicConnection): nursery.start_soon(self._worker) self._run_done.set() - async def make_stream(self): - await self._handshake_complete.wait() - if self._done: - raise UnexpectedEOF - stream_id = self._connection.get_next_available_stream_id(False) - stream = TrioQuicStream(self, stream_id) - self._streams[stream_id] = stream - return stream + async def make_stream(self, timeout=None): + if timeout is None: + context = NullContext(None) + else: + context = trio.move_on_after(timeout) + with context: + await self._handshake_complete.wait() + if self._done: + raise UnexpectedEOF + stream_id = self._connection.get_next_available_stream_id(False) + stream = TrioQuicStream(self, stream_id) + self._streams[stream_id] = stream + return stream + raise dns.exception.Timeout async def close(self): if not self._closed: -- 2.47.3