]> git.ipfire.org Git - thirdparty/bind9.git/commitdiff
Simplify peer address formatting
authorMichał Kępień <michal@isc.org>
Tue, 18 Mar 2025 15:28:18 +0000 (16:28 +0100)
committerMichał Kępień <michal@isc.org>
Tue, 18 Mar 2025 15:28:18 +0000 (16:28 +0100)
Add a helper class, Peer, which holds the <host, port> tuple of a
connection endpoint and gets pretty-printed when formatted as a string.
This enables passing instances of this new class directly to logging
functions, eliminating the need for the AsyncDnsServer._format_peer()
helper method.

bin/tests/system/isctest/asyncserver.py

index ab508b404ae898b7dd0a8dd4fc041f98f8b4c744..211d1402218d1f49fe0ca1743ea9db6f3cf9d903 100644 (file)
@@ -224,6 +224,20 @@ class DnsProtocol(enum.Enum):
     TCP = enum.auto()
 
 
+@dataclass(frozen=True)
+class Peer:
+    """
+    Pretty-printed connection endpoint.
+    """
+
+    host: str
+    port: int
+
+    def __str__(self) -> str:
+        host = f"[{self.host}]" if ":" in self.host else self.host
+        return f"{host}:{self.port}"
+
+
 @dataclass
 class QueryContext:
     """
@@ -232,7 +246,7 @@ class QueryContext:
 
     query: dns.message.Message
     response: dns.message.Message
-    peer: Tuple[str, int]
+    peer: Peer
     protocol: DnsProtocol
     zone: Optional[dns.zone.Zone] = None
     soa: Optional[dns.rrset.RRset] = None
@@ -513,16 +527,20 @@ class AsyncDnsServer(AsyncServer):
             self._zone_tree.add(zone)
 
     async def _handle_udp(
-        self, wire: bytes, peer: Tuple[str, int], transport: asyncio.DatagramTransport
+        self, wire: bytes, addr: Tuple[str, int], transport: asyncio.DatagramTransport
     ) -> None:
         logging.debug("Received UDP message: %s", wire.hex())
+        peer = Peer(addr[0], addr[1])
         responses = self._handle_query(wire, peer, DnsProtocol.UDP)
         async for response in responses:
-            transport.sendto(response, peer)
+            transport.sendto(response, addr)
 
     async def _handle_tcp(
         self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
     ) -> None:
+        peer_info = writer.get_extra_info("peername")
+        peer = Peer(peer_info[0], peer_info[1])
+
         wire_length_bytes = await reader.read(2)
         (wire_length,) = struct.unpack("!H", wire_length_bytes)
         logging.debug("Receiving TCP message (%d octets)...", wire_length)
@@ -531,38 +549,26 @@ class AsyncDnsServer(AsyncServer):
         full_message = wire_length_bytes + wire
         logging.debug("Received complete TCP message: %s", full_message.hex())
 
-        peer = writer.get_extra_info("peername")
         responses = self._handle_query(wire, peer, DnsProtocol.TCP)
         async for response in responses:
             writer.write(response)
             try:
                 await writer.drain()
             except ConnectionResetError:
-                logging.error(
-                    "TCP connection from %s reset by peer", self._format_peer(peer)
-                )
+                logging.error("TCP connection from %s reset by peer", peer)
                 return
 
         writer.close()
         await writer.wait_closed()
 
-    def _format_peer(self, peer: Tuple[str, int]) -> str:
-        host = peer[0]
-        port = peer[1]
-        if "::" in host:
-            host = f"[{host}]"
-        return f"{host}:{port}"
-
-    def _log_query(
-        self, qctx: QueryContext, peer: Tuple[str, int], protocol: DnsProtocol
-    ) -> None:
+    def _log_query(self, qctx: QueryContext, peer: Peer, protocol: DnsProtocol) -> None:
         logging.info(
             "Received %s/%s/%s (ID=%d) query from %s (%s)",
             qctx.qname.to_text(omit_final_dot=True),
             dns.rdataclass.to_text(qctx.qclass),
             dns.rdatatype.to_text(qctx.qtype),
             qctx.query.id,
-            self._format_peer(peer),
+            peer,
             protocol.name,
         )
         logging.debug(
@@ -573,14 +579,14 @@ class AsyncDnsServer(AsyncServer):
         self,
         qctx: QueryContext,
         response: Optional[Union[dns.message.Message, bytes]],
-        peer: Tuple[str, int],
+        peer: Peer,
         protocol: DnsProtocol,
     ) -> None:
         if not response:
             logging.info(
                 "Not sending a response to query (ID=%d) from %s (%s)",
                 qctx.query.id,
-                self._format_peer(peer),
+                peer,
                 protocol.name,
             )
             return
@@ -606,7 +612,7 @@ class AsyncDnsServer(AsyncServer):
                 len(response.authority),
                 len(response.additional),
                 qctx.query.id,
-                self._format_peer(peer),
+                peer,
                 protocol.name,
             )
             logging.debug(
@@ -618,13 +624,13 @@ class AsyncDnsServer(AsyncServer):
             "Sending response (%d bytes) to a query (ID=%d) from %s (%s)",
             len(response),
             qctx.query.id,
-            self._format_peer(peer),
+            peer,
             protocol.name,
         )
         logging.debug("[OUT] %s", response.hex())
 
     async def _handle_query(
-        self, wire: bytes, peer: Tuple[str, int], protocol: DnsProtocol
+        self, wire: bytes, peer: Peer, protocol: DnsProtocol
     ) -> AsyncGenerator[bytes, None]:
         """
         Yield wire data to send as a response over the established transport.