]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Use queue instead of single future in asyncio datagram socket (#1250)
authordarkdragon-001 <darkdragon-001@users.noreply.github.com>
Fri, 2 Jan 2026 21:14:52 +0000 (22:14 +0100)
committerGitHub <noreply@github.com>
Fri, 2 Jan 2026 21:14:52 +0000 (13:14 -0800)
Previously, a single future was used and any package which was received before the future was awaited was silently discarded.

dns/_asyncio_backend.py

index e0c367ed307a355ad0fc998426bb8fdd65c9af0a..ef0bbf5e8b891ab37649de5d83dd28a07d2be899 100644 (file)
@@ -24,29 +24,22 @@ def _get_running_loop():
 class _DatagramProtocol(asyncio.DatagramProtocol):
     def __init__(self):
         self.transport = None
-        self.recvfrom = None
+        self.recvq = asyncio.Queue()
 
     def connection_made(self, transport):
         self.transport = transport
 
     def datagram_received(self, data, addr):
-        if self.recvfrom and not self.recvfrom.done():
-            self.recvfrom.set_result((data, addr))
+        self.recvq.put_nowait((data, addr))
 
-    def error_received(self, exc):  # pragma: no cover
-        if self.recvfrom and not self.recvfrom.done():
-            self.recvfrom.set_exception(exc)
+    def error_received(self, exc):
+        self.recvq.put_nowait(exc)
 
     def connection_lost(self, exc):
-        if self.recvfrom and not self.recvfrom.done():
-            if exc is None:
-                # EOF we triggered.  Is there a better way to do this?
-                try:
-                    raise EOFError("EOF")
-                except EOFError as e:
-                    self.recvfrom.set_exception(e)
-            else:
-                self.recvfrom.set_exception(exc)
+        if exc is None:
+            self.recvq.put_nowait(EOFError("EOF"))
+        else:
+            self.recvq.put_nowait(exc)
 
     def close(self):
         if self.transport is not None:
@@ -76,14 +69,10 @@ class _DatagramSocket(dns._asyncbackend.DatagramSocket):
 
     async def recvfrom(self, size, timeout):
         # ignore size as there's no way I know to tell protocol about it
-        done = _get_running_loop().create_future()
-        try:
-            assert self.protocol.recvfrom is None
-            self.protocol.recvfrom = done
-            await _maybe_wait_for(done, timeout)
-            return done.result()
-        finally:
-            self.protocol.recvfrom = None
+        pkg = await _maybe_wait_for(self.protocol.recvq.get(), timeout)
+        if isinstance(pkg, BaseException):
+            raise pkg
+        return pkg
 
     async def close(self):
         self.protocol.close()