import threading
import trio
+import dns.asyncquery
import dns.message
import dns.rcode
-import dns.trio.query
+
+async def read_exactly(stream, count):
+ """Read the specified number of bytes from stream. Keep trying until we
+ either get the desired amount, or we hit EOF.
+ """
+ s = b''
+ while count > 0:
+ n = await stream.receive_some(count)
+ if n == b'':
+ raise EOFError
+ count = count - len(n)
+ s = s + n
+ return s
class ConnectionType(enum.IntEnum):
UDP = 1
try:
peer = stream.socket.getpeername()
while True:
- ldata = await dns.trio.query.read_exactly(stream, 2)
+ ldata = await read_exactly(stream, 2)
(l,) = struct.unpack("!H", ldata)
- wire = await dns.trio.query.read_exactly(stream, l)
+ wire = await read_exactly(stream, l)
wire = self.handle_wire(wire, peer, ConnectionType.TCP)
if wire is not None:
l = len(wire)