# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import contextlib
+import enum
import functools
import socket
import struct
import dns.rcode
import dns.trio.query
+class ConnectionType(enum.IntEnum):
+ UDP = 1
+ TCP = 2
class Server(threading.Thread):
finally:
raise EOFError
- def handle(self, message):
+ def handle(self, message, peer, connection_type):
#
# Handle message 'message'. Override this method to change
# how the server behaves.
except Exception:
return None
- def handle_wire(self, wire):
+ def handle_wire(self, wire, peer, connection_type):
#
# This is the common code to parse wire format, call handle() on
# the message, and then generate resposne wire format (if handle()
#
# Returns a wire format message to send, or None indicating there
# is nothing to do.
+ #
+ # XXXRTH It might be nice to have a "debug mode" in the server
+ # where we'd print something in all the places we're eating
+ # exceptions. That way bugs in handle() would be easier to
+ # find.
+ #
r = None
try:
q = dns.message.from_wire(wire)
# r might have been set above, so skip handle() if we
# already have a response.
if r is None:
- r = self.handle(q)
+ r = self.handle(q, peer, connection_type)
except Exception:
# Exceptions from handle get a SERVFAIL response.
r = dns.message.make_response(q)
self.udp = None # we own cleanup
while True:
try:
- (wire, from_address) = await sock.recvfrom(65535)
- wire = self.handle_wire(wire)
+ (wire, peer) = await sock.recvfrom(65535)
+ wire = self.handle_wire(wire, peer, ConnectionType.UDP)
if wire is not None:
- await sock.sendto(wire, from_address)
+ await sock.sendto(wire, peer)
except Exception:
pass
async def serve_tcp(self, stream):
try:
+ peer = stream.socket.getpeername()
while True:
ldata = await dns.trio.query.read_exactly(stream, 2)
(l,) = struct.unpack("!H", ldata)
wire = await dns.trio.query.read_exactly(stream, l)
- wire = self.handle_wire(wire)
+ wire = self.handle_wire(wire, peer, ConnectionType.TCP)
if wire is not None:
l = len(wire)
stream_message = struct.pack("!H", l) + wire