peer_info = writer.get_extra_info("peername")
peer = Peer(peer_info[0], peer_info[1])
+ for _ in range(0, 1):
+ wire = await self._read_tcp_query(reader)
+ if not wire:
+ break
+ await self._send_tcp_response(writer, peer, wire)
+
+ writer.close()
+ await writer.wait_closed()
+
+ async def _read_tcp_query(self, reader: asyncio.StreamReader) -> Optional[bytes]:
+ wire_length = await self._read_tcp_query_wire_length(reader)
+ if not wire_length:
+ return None
+
+ return await self._read_tcp_query_wire(reader, wire_length)
+
+ async def _read_tcp_query_wire_length(
+ self, reader: asyncio.StreamReader
+ ) -> Optional[int]:
wire_length_bytes = await reader.read(2)
if len(wire_length_bytes) < 2:
- return
+ return None
+
(wire_length,) = struct.unpack("!H", wire_length_bytes)
+
+ return wire_length
+
+ async def _read_tcp_query_wire(
+ self, reader: asyncio.StreamReader, wire_length: int
+ ) -> Optional[bytes]:
logging.debug("Receiving TCP message (%d octets)...", wire_length)
wire = await reader.read(wire_length)
if len(wire) < wire_length:
- return
- full_message = wire_length_bytes + wire
- logging.debug("Received complete TCP message: %s", full_message.hex())
+ return None
+
+ logging.debug("Received complete TCP message: %s", wire.hex())
+
+ return wire
+ async def _send_tcp_response(
+ self, writer: asyncio.StreamWriter, peer: Peer, wire: bytes
+ ) -> None:
responses = self._handle_query(wire, peer, DnsProtocol.TCP)
async for response in responses:
writer.write(response)
logging.error("TCP connection from %s reset by peer", peer)
return
- writer.close()
- await writer.wait_closed()
-
def _log_query(self, qctx: QueryContext, peer: Peer, protocol: DnsProtocol) -> None:
logging.info(
"Received %s/%s/%s (ID=%d) query from %s (%s)",