]> git.ipfire.org Git - thirdparty/bind9.git/commitdiff
Store server socket information in QueryContext
authorMichał Kępień <michal@isc.org>
Fri, 13 Feb 2026 13:27:10 +0000 (14:27 +0100)
committerMichał Kępień <michal@isc.org>
Fri, 13 Feb 2026 13:27:10 +0000 (14:27 +0100)
Extend the QueryContext class with a field holding the <address, port>
tuple for the socket on which a given query was received.  This will
enable query handlers to act upon that information in arbitrary ways.

bin/tests/system/isctest/asyncserver.py

index dd784ef5b65ce275707fbf8f5ccf905d5082d164..3e8e59080fa732bc02c4f5fc1fd5fe6b2ec3a55a 100644 (file)
@@ -266,6 +266,7 @@ class QueryContext:
 
     query: dns.message.Message
     response: dns.message.Message
+    socket: Peer
     peer: Peer
     protocol: DnsProtocol
     zone: Optional[dns.zone.Zone] = field(default=None, init=False)
@@ -1072,8 +1073,10 @@ class AsyncDnsServer(AsyncServer):
         self, wire: bytes, addr: Tuple[str, int], transport: asyncio.DatagramTransport
     ) -> None:
         logging.debug("Received UDP message: %s", wire.hex())
+        socket_info = transport.get_extra_info("sockname")
+        socket = Peer(socket_info[0], socket_info[1])
         peer = Peer(addr[0], addr[1])
-        responses = self._handle_query(wire, peer, DnsProtocol.UDP)
+        responses = self._handle_query(wire, socket, peer, DnsProtocol.UDP)
         async for response in responses:
             logging.debug("Sending UDP message: %s", response.hex())
             transport.sendto(response, addr)
@@ -1170,7 +1173,9 @@ class AsyncDnsServer(AsyncServer):
     async def _send_tcp_response(
         self, writer: asyncio.StreamWriter, peer: Peer, wire: bytes
     ) -> None:
-        responses = self._handle_query(wire, peer, DnsProtocol.TCP)
+        socket_info = writer.get_extra_info("sockname")
+        socket = Peer(socket_info[0], socket_info[1])
+        responses = self._handle_query(wire, socket, peer, DnsProtocol.TCP)
         async for response in responses:
             logging.debug("Sending TCP response: %s", response.hex())
             writer.write(response)
@@ -1245,7 +1250,7 @@ class AsyncDnsServer(AsyncServer):
         logging.debug("[OUT] %s", response.hex())
 
     async def _handle_query(
-        self, wire: bytes, peer: Peer, protocol: DnsProtocol
+        self, wire: bytes, socket: Peer, peer: Peer, protocol: DnsProtocol
     ) -> AsyncGenerator[bytes, None]:
         """
         Yield wire data to send as a response over the established transport.
@@ -1256,7 +1261,7 @@ class AsyncDnsServer(AsyncServer):
             logging.error("Invalid query from %s (%s): %s", peer, wire.hex(), exc)
             return
         response_stub = _make_asyncserver_response(query)
-        qctx = QueryContext(query, response_stub, peer, protocol)
+        qctx = QueryContext(query, response_stub, socket, peer, protocol)
         self._log_query(qctx, peer, protocol)
         responses = self._prepare_responses(qctx)
         async for response in responses: