]> git.ipfire.org Git - thirdparty/bind9.git/commitdiff
Enable receiving chunked TCP DNS messages
authorMichał Kępień <michal@isc.org>
Tue, 18 Mar 2025 15:28:18 +0000 (16:28 +0100)
committerMichał Kępień <michal@isc.org>
Tue, 18 Mar 2025 15:28:18 +0000 (16:28 +0100)
A TCP DNS client may send its queries in chunks, causing
StreamReader.read() to return less data than previously declared by the
client as the DNS message length; even the two-octet DNS message length
itself may be split up into two single-octet transmissions.  Sending
data in chunks is valid client behavior that should not be treated as an
error.  Add a new helper method for reading TCP data in a loop, properly
distinguishing between chunked queries and client disconnections.  Use
the new method for reading all TCP data from clients.

bin/tests/system/isctest/asyncserver.py

index 902f436ccc480a22b62520f35901edadc8774fe7..0d8996e8e2cdc8c18799d3e374f40967cb2958cb 100644 (file)
@@ -570,8 +570,8 @@ class AsyncDnsServer(AsyncServer):
     ) -> Optional[int]:
         logging.debug("Receiving TCP message length from %s...", peer)
 
-        wire_length_bytes = await reader.read(2)
-        if len(wire_length_bytes) < 2:
+        wire_length_bytes = await self._read_tcp_octets(reader, peer, 2)
+        if not wire_length_bytes:
             return None
 
         (wire_length,) = struct.unpack("!H", wire_length_bytes)
@@ -583,14 +583,38 @@ class AsyncDnsServer(AsyncServer):
     ) -> Optional[bytes]:
         logging.debug("Receiving TCP message (%d octets) from %s...", wire_length, peer)
 
-        wire = await reader.read(wire_length)
-        if len(wire) < wire_length:
+        wire = await self._read_tcp_octets(reader, peer, wire_length)
+        if not wire:
             return None
 
         logging.debug("Received complete TCP message from %s: %s", peer, wire.hex())
 
         return wire
 
+    async def _read_tcp_octets(
+        self, reader: asyncio.StreamReader, peer: Peer, expected: int
+    ) -> Optional[bytes]:
+        buffer = b""
+
+        while len(buffer) < expected:
+            chunk = await reader.read(expected - len(buffer))
+            if not chunk:
+                if buffer:
+                    logging.debug(
+                        "Received short TCP message (%d octets) from %s: %s",
+                        len(buffer),
+                        peer,
+                        buffer.hex(),
+                    )
+                else:
+                    logging.debug("Received disconnect from %s", peer)
+                return None
+
+            logging.debug("Received %d TCP octets from %s", len(chunk), peer)
+            buffer += chunk
+
+        return buffer
+
     async def _send_tcp_response(
         self, writer: asyncio.StreamWriter, peer: Peer, wire: bytes
     ) -> None: