]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Fix two QUIC issues:
authorBob Halley <halley@dnspython.org>
Sun, 22 Oct 2023 14:12:41 +0000 (07:12 -0700)
committerBob Halley <halley@dnspython.org>
Sun, 22 Oct 2023 14:12:41 +0000 (07:12 -0700)
  1) We treated stream reset like connection terminated, which
     is just wrong.  We should send EOF to the stream but leave
     the connection alone.

  2) When we got an unexpected EOF on a stream, we raised the
     exception in the wrong place, killing the QUIC connection
     but leaving the stream blocked.  Now we deliver the exception
     to the stream and don't kill the connection.

dns/quic/_asyncio.py
dns/quic/_common.py
dns/quic/_sync.py
dns/quic/_trio.py

index e1c52339d30ca0332593254fbaaed677657e74f1..b05748309d07e7654a41e32d58b08f5a6ba5da81 100644 (file)
@@ -147,11 +147,14 @@ class AsyncioQuicConnection(AsyncQuicConnection):
                     await stream._add_input(event.data, event.end_stream)
             elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
                 self._handshake_complete.set()
-            elif isinstance(
-                event, aioquic.quic.events.ConnectionTerminated
-            ) or isinstance(event, aioquic.quic.events.StreamReset):
+            elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
                 self._done = True
                 self._receiver_task.cancel()
+            elif isinstance(event, aioquic.quic.events.StreamReset):
+                stream = self._streams.get(event.stream_id)
+                if stream:
+                    await stream._add_input(b"", True)
+
             count += 1
             if count > 10:
                 # yield
index 38ec103ff8c04b0fa6da4b5e67c56f4fec989b11..e4a9f18dbd180d688cb8dc7dd6c3fa4e092114f8 100644 (file)
@@ -79,7 +79,10 @@ class BaseQuicStream:
 
     def _common_add_input(self, data, is_end):
         self._buffer.put(data, is_end)
-        return self._expecting > 0 and self._buffer.have(self._expecting)
+        try:
+            return self._expecting > 0 and self._buffer.have(self._expecting)
+        except UnexpectedEOF:
+            return True
 
     def _close(self):
         self._connection.close_stream(self._stream_id)
index e944784dee94ac3ac39ff27a48653d534b638068..6e13cad49ab6aaa28363720700a844aff03c8989 100644 (file)
@@ -155,11 +155,14 @@ class SyncQuicConnection(BaseQuicConnection):
                     stream._add_input(event.data, event.end_stream)
             elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
                 self._handshake_complete.set()
-            elif isinstance(
-                event, aioquic.quic.events.ConnectionTerminated
-            ) or isinstance(event, aioquic.quic.events.StreamReset):
+            elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
                 with self._lock:
                     self._done = True
+            elif isinstance(event, aioquic.quic.events.StreamReset):
+                with self._lock:
+                    stream = self._streams.get(event.stream_id)
+                if stream:
+                    stream._add_input(b"", True)
 
     def write(self, stream, data, is_end=False):
         with self._lock:
index ee07e4f6e8808fd8a134dd33a2c8e9f9e8146fea..43c1b1a491db3de7585e023664d01c9f22d8d0c3 100644 (file)
@@ -116,11 +116,13 @@ class TrioQuicConnection(AsyncQuicConnection):
                     await stream._add_input(event.data, event.end_stream)
             elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
                 self._handshake_complete.set()
-            elif isinstance(
-                event, aioquic.quic.events.ConnectionTerminated
-            ) or isinstance(event, aioquic.quic.events.StreamReset):
+            elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
                 self._done = True
                 self._socket.close()
+            elif isinstance(event, aioquic.quic.events.StreamReset):
+                stream = self._streams.get(event.stream_id)
+                if stream:
+                    await stream._add_input(b"", True)
             count += 1
             if count > 10:
                 # yield