) -> Optional[int]:
logging.debug("Receiving TCP message length from %s...", peer)
- wire_length_bytes = await reader.read(2)
- if len(wire_length_bytes) < 2:
+ wire_length_bytes = await self._read_tcp_octets(reader, peer, 2)
+ if not wire_length_bytes:
return None
(wire_length,) = struct.unpack("!H", wire_length_bytes)
) -> Optional[bytes]:
logging.debug("Receiving TCP message (%d octets) from %s...", wire_length, peer)
- wire = await reader.read(wire_length)
- if len(wire) < wire_length:
+ wire = await self._read_tcp_octets(reader, peer, wire_length)
+ if not wire:
return None
logging.debug("Received complete TCP message from %s: %s", peer, wire.hex())
return wire
+ async def _read_tcp_octets(
+ self, reader: asyncio.StreamReader, peer: Peer, expected: int
+ ) -> Optional[bytes]:
+ buffer = b""
+
+ while len(buffer) < expected:
+ chunk = await reader.read(expected - len(buffer))
+ if not chunk:
+ if buffer:
+ logging.debug(
+ "Received short TCP message (%d octets) from %s: %s",
+ len(buffer),
+ peer,
+ buffer.hex(),
+ )
+ else:
+ logging.debug("Received disconnect from %s", peer)
+ return None
+
+ logging.debug("Received %d TCP octets from %s", len(chunk), peer)
+ buffer += chunk
+
+ return buffer
+
async def _send_tcp_response(
self, writer: asyncio.StreamWriter, peer: Peer, wire: bytes
) -> None: