]> git.ipfire.org Git - thirdparty/bind9.git/commitdiff
Refactor AsyncDnsServer._handle_tcp()
authorMichał Kępień <michal@isc.org>
Tue, 18 Mar 2025 15:28:18 +0000 (16:28 +0100)
committerMichał Kępień <michal@isc.org>
Tue, 18 Mar 2025 15:28:18 +0000 (16:28 +0100)
Split up AsyncDnsServer._handle_tcp() into a set of smaller methods to
improve code readability.

bin/tests/system/isctest/asyncserver.py

index 952cb797563400042af379dfc3ffffe78902b08c..7a3285402cab1db59596a6c41127bfd9599a5cbb 100644 (file)
@@ -541,18 +541,49 @@ class AsyncDnsServer(AsyncServer):
         peer_info = writer.get_extra_info("peername")
         peer = Peer(peer_info[0], peer_info[1])
 
+        for _ in range(0, 1):
+            wire = await self._read_tcp_query(reader)
+            if not wire:
+                break
+            await self._send_tcp_response(writer, peer, wire)
+
+        writer.close()
+        await writer.wait_closed()
+
+    async def _read_tcp_query(self, reader: asyncio.StreamReader) -> Optional[bytes]:
+        wire_length = await self._read_tcp_query_wire_length(reader)
+        if not wire_length:
+            return None
+
+        return await self._read_tcp_query_wire(reader, wire_length)
+
+    async def _read_tcp_query_wire_length(
+        self, reader: asyncio.StreamReader
+    ) -> Optional[int]:
         wire_length_bytes = await reader.read(2)
         if len(wire_length_bytes) < 2:
-            return
+            return None
+
         (wire_length,) = struct.unpack("!H", wire_length_bytes)
+
+        return wire_length
+
+    async def _read_tcp_query_wire(
+        self, reader: asyncio.StreamReader, wire_length: int
+    ) -> Optional[bytes]:
         logging.debug("Receiving TCP message (%d octets)...", wire_length)
 
         wire = await reader.read(wire_length)
         if len(wire) < wire_length:
-            return
-        full_message = wire_length_bytes + wire
-        logging.debug("Received complete TCP message: %s", full_message.hex())
+            return None
+
+        logging.debug("Received complete TCP message: %s", wire.hex())
+
+        return wire
 
+    async def _send_tcp_response(
+        self, writer: asyncio.StreamWriter, peer: Peer, wire: bytes
+    ) -> None:
         responses = self._handle_query(wire, peer, DnsProtocol.TCP)
         async for response in responses:
             writer.write(response)
@@ -562,9 +593,6 @@ class AsyncDnsServer(AsyncServer):
                 logging.error("TCP connection from %s reset by peer", peer)
                 return
 
-        writer.close()
-        await writer.wait_closed()
-
     def _log_query(self, qctx: QueryContext, peer: Peer, protocol: DnsProtocol) -> None:
         logging.info(
             "Received %s/%s/%s (ID=%d) query from %s (%s)",