From 954856b07e53f135691e372e3bd30a98d5d33ce0 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sun, 22 Oct 2023 07:12:41 -0700 Subject: [PATCH] Fix two QUIC issues: 1) We treated stream reset like connection terminated, which is just wrong. We should send EOF to the stream but leave the connection alone. 2) When we got an unexpected EOF on a stream, we raised the exception in the wrong place, killing the QUIC connection but leaving the stream blocked. Now we deliver the exception to the stream and don't kill the connection. --- dns/quic/_asyncio.py | 9 ++++++--- dns/quic/_common.py | 5 ++++- dns/quic/_sync.py | 9 ++++++--- dns/quic/_trio.py | 8 +++++--- 4 files changed, 21 insertions(+), 10 deletions(-) diff --git a/dns/quic/_asyncio.py b/dns/quic/_asyncio.py index e1c52339..b0574830 100644 --- a/dns/quic/_asyncio.py +++ b/dns/quic/_asyncio.py @@ -147,11 +147,14 @@ class AsyncioQuicConnection(AsyncQuicConnection): await stream._add_input(event.data, event.end_stream) elif isinstance(event, aioquic.quic.events.HandshakeCompleted): self._handshake_complete.set() - elif isinstance( - event, aioquic.quic.events.ConnectionTerminated - ) or isinstance(event, aioquic.quic.events.StreamReset): + elif isinstance(event, aioquic.quic.events.ConnectionTerminated): self._done = True self._receiver_task.cancel() + elif isinstance(event, aioquic.quic.events.StreamReset): + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input(b"", True) + count += 1 if count > 10: # yield diff --git a/dns/quic/_common.py b/dns/quic/_common.py index 38ec103f..e4a9f18d 100644 --- a/dns/quic/_common.py +++ b/dns/quic/_common.py @@ -79,7 +79,10 @@ class BaseQuicStream: def _common_add_input(self, data, is_end): self._buffer.put(data, is_end) - return self._expecting > 0 and self._buffer.have(self._expecting) + try: + return self._expecting > 0 and self._buffer.have(self._expecting) + except UnexpectedEOF: + return True def _close(self): self._connection.close_stream(self._stream_id) diff --git a/dns/quic/_sync.py b/dns/quic/_sync.py index e944784d..6e13cad4 100644 --- a/dns/quic/_sync.py +++ b/dns/quic/_sync.py @@ -155,11 +155,14 @@ class SyncQuicConnection(BaseQuicConnection): stream._add_input(event.data, event.end_stream) elif isinstance(event, aioquic.quic.events.HandshakeCompleted): self._handshake_complete.set() - elif isinstance( - event, aioquic.quic.events.ConnectionTerminated - ) or isinstance(event, aioquic.quic.events.StreamReset): + elif isinstance(event, aioquic.quic.events.ConnectionTerminated): with self._lock: self._done = True + elif isinstance(event, aioquic.quic.events.StreamReset): + with self._lock: + stream = self._streams.get(event.stream_id) + if stream: + stream._add_input(b"", True) def write(self, stream, data, is_end=False): with self._lock: diff --git a/dns/quic/_trio.py b/dns/quic/_trio.py index ee07e4f6..43c1b1a4 100644 --- a/dns/quic/_trio.py +++ b/dns/quic/_trio.py @@ -116,11 +116,13 @@ class TrioQuicConnection(AsyncQuicConnection): await stream._add_input(event.data, event.end_stream) elif isinstance(event, aioquic.quic.events.HandshakeCompleted): self._handshake_complete.set() - elif isinstance( - event, aioquic.quic.events.ConnectionTerminated - ) or isinstance(event, aioquic.quic.events.StreamReset): + elif isinstance(event, aioquic.quic.events.ConnectionTerminated): self._done = True self._socket.close() + elif isinstance(event, aioquic.quic.events.StreamReset): + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input(b"", True) count += 1 if count > 10: # yield -- 2.47.3