TCP = enum.auto()
+@dataclass(frozen=True)
+class Peer:
+ """
+ Pretty-printed connection endpoint.
+ """
+
+ host: str
+ port: int
+
+ def __str__(self) -> str:
+ host = f"[{self.host}]" if ":" in self.host else self.host
+ return f"{host}:{self.port}"
+
+
@dataclass
class QueryContext:
"""
query: dns.message.Message
response: dns.message.Message
- peer: Tuple[str, int]
+ peer: Peer
protocol: DnsProtocol
zone: Optional[dns.zone.Zone] = None
soa: Optional[dns.rrset.RRset] = None
self._zone_tree.add(zone)
async def _handle_udp(
- self, wire: bytes, peer: Tuple[str, int], transport: asyncio.DatagramTransport
+ self, wire: bytes, addr: Tuple[str, int], transport: asyncio.DatagramTransport
) -> None:
logging.debug("Received UDP message: %s", wire.hex())
+ peer = Peer(addr[0], addr[1])
responses = self._handle_query(wire, peer, DnsProtocol.UDP)
async for response in responses:
- transport.sendto(response, peer)
+ transport.sendto(response, addr)
async def _handle_tcp(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
+ peer_info = writer.get_extra_info("peername")
+ peer = Peer(peer_info[0], peer_info[1])
+
wire_length_bytes = await reader.read(2)
(wire_length,) = struct.unpack("!H", wire_length_bytes)
logging.debug("Receiving TCP message (%d octets)...", wire_length)
full_message = wire_length_bytes + wire
logging.debug("Received complete TCP message: %s", full_message.hex())
- peer = writer.get_extra_info("peername")
responses = self._handle_query(wire, peer, DnsProtocol.TCP)
async for response in responses:
writer.write(response)
try:
await writer.drain()
except ConnectionResetError:
- logging.error(
- "TCP connection from %s reset by peer", self._format_peer(peer)
- )
+ logging.error("TCP connection from %s reset by peer", peer)
return
writer.close()
await writer.wait_closed()
- def _format_peer(self, peer: Tuple[str, int]) -> str:
- host = peer[0]
- port = peer[1]
- if "::" in host:
- host = f"[{host}]"
- return f"{host}:{port}"
-
- def _log_query(
- self, qctx: QueryContext, peer: Tuple[str, int], protocol: DnsProtocol
- ) -> None:
+ def _log_query(self, qctx: QueryContext, peer: Peer, protocol: DnsProtocol) -> None:
logging.info(
"Received %s/%s/%s (ID=%d) query from %s (%s)",
qctx.qname.to_text(omit_final_dot=True),
dns.rdataclass.to_text(qctx.qclass),
dns.rdatatype.to_text(qctx.qtype),
qctx.query.id,
- self._format_peer(peer),
+ peer,
protocol.name,
)
logging.debug(
self,
qctx: QueryContext,
response: Optional[Union[dns.message.Message, bytes]],
- peer: Tuple[str, int],
+ peer: Peer,
protocol: DnsProtocol,
) -> None:
if not response:
logging.info(
"Not sending a response to query (ID=%d) from %s (%s)",
qctx.query.id,
- self._format_peer(peer),
+ peer,
protocol.name,
)
return
len(response.authority),
len(response.additional),
qctx.query.id,
- self._format_peer(peer),
+ peer,
protocol.name,
)
logging.debug(
"Sending response (%d bytes) to a query (ID=%d) from %s (%s)",
len(response),
qctx.query.id,
- self._format_peer(peer),
+ peer,
protocol.name,
)
logging.debug("[OUT] %s", response.hex())
async def _handle_query(
- self, wire: bytes, peer: Tuple[str, int], protocol: DnsProtocol
+ self, wire: bytes, peer: Peer, protocol: DnsProtocol
) -> AsyncGenerator[bytes, None]:
"""
Yield wire data to send as a response over the established transport.