raise _ConnectionTeardownRequested
+class ConnectionHandler(abc.ABC):
+ """
+ Base class for TCP connection handlers.
+
+ An installed connection handler is called when a new TCP connection is
+ established. It may be used to perform arbitrary actions before
+ AsyncDnsServer processes DNS queries.
+ """
+
+ @abc.abstractmethod
+ async def handle(
+ self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer: Peer
+ ) -> None:
+ """
+ Handle the connection with the provided reader and writer.
+ """
+ raise NotImplementedError
+
+
class ResponseHandler(abc.ABC):
"""
Base class for generic response handlers.
super().__init__(self._handle_udp, self._handle_tcp, "ans.pid")
self._zone_tree: _ZoneTree = _ZoneTree()
+ self._connection_handler: Optional[ConnectionHandler] = None
self._response_handlers: List[ResponseHandler] = []
self._acknowledge_manual_dname_handling = acknowledge_manual_dname_handling
self._acknowledge_tsig_dnspython_hacks = acknowledge_tsig_dnspython_hacks
logging.info("Uninstalling response handler: %s", handler)
self._response_handlers.remove(handler)
+ def install_connection_handler(self, handler: ConnectionHandler) -> None:
+ """
+ Install a connection handler that will be called when a new TCP
+ connection is established.
+ """
+ if self._connection_handler:
+ raise RuntimeError("Only one connection handler can be installed")
+ self._connection_handler = handler
+
def _load_zones(self) -> None:
for entry in os.scandir():
entry_path = pathlib.Path(entry.path)
peer = Peer(peer_info[0], peer_info[1])
logging.debug("Accepted TCP connection from %s", peer)
- while True:
- try:
+ try:
+ if self._connection_handler:
+ await self._connection_handler.handle(reader, writer, peer)
+ while True:
wire = await self._read_tcp_query(reader, peer)
if not wire:
break
await self._send_tcp_response(writer, peer, wire)
- except _ConnectionTeardownRequested:
- break
- except ConnectionResetError:
- logging.error("TCP connection from %s reset by peer", peer)
- return
+ except _ConnectionTeardownRequested:
+ pass
+ except ConnectionResetError:
+ logging.error("TCP connection from %s reset by peer", peer)
+ return
logging.debug("Closing TCP connection from %s", peer)
writer.close()