"""
def __init__(self, address='127.0.0.1', port=0, enable_udp=True,
- enable_tcp=True, use_thread=True):
+ enable_tcp=True, use_thread=True, origin=None):
super().__init__()
self.address = address
self.port = port
self.enable_udp = enable_udp
self.enable_tcp = enable_tcp
self.use_thread = use_thread
+ self.origin = origin
self.left = None
self.right = None
self.udp = None
self.tcp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.tcp.bind((self.address, self.port))
self.tcp.listen()
- self.tcp_address = self.udp.getsockname()
+ self.tcp_address = self.tcp.getsockname()
if self.use_thread:
self.start()
return self
# how the server behaves.
#
# The return value is either a dns.message.Message, a bytes,
- # or None. We allow a bytes to be returned for cases where
- # handle wants to return an invalid DNS message for testing
- # purposes. We allow None to be returned to indicate there is
- # no response.
+ # None, or a list of one of those. We allow a bytes to be
+ # returned for cases where handle wants to return an invalid
+ # DNS message for testing purposes. We allow None to be
+ # returned to indicate there is no response. If a list is
+ # returned, then the output code will run for each returned
+ # item.
#
try:
r = dns.message.make_response(message)
except Exception:
return None
+ def maybe_listify(self, thing):
+ if isinstance(thing, list):
+ return thing
+ else:
+ return [thing]
+
+ def maybe_render(self, thing):
+ if isinstance(thing, dns.message.Message):
+ return thing.to_wire(self.origin)
+ else:
+ return thing
+
def handle_wire(self, wire, peer, connection_type):
#
# This is the common code to parse wire format, call handle() on
- # the message, and then generate resposne wire format (if handle()
+ # the message, and then generate response wire format (if handle()
# didn't do it).
#
# It also handles any exceptions from handle()
#
- # Returns a wire format message to send, or None indicating there
- # is nothing to do.
+ # Returns a (possibly empty) list of wire format message to send.
#
# XXXRTH It might be nice to have a "debug mode" in the server
# where we'd print something in all the places we're eating
# exceptions. That way bugs in handle() would be easier to
# find.
#
+ items = []
r = None
try:
q = dns.message.from_wire(wire)
except dns.message.ShortHeader:
# There is no hope of answering this one!
- return None
+ return []
except Exception:
# Try to make a FORMERR using just the question section.
try:
q = dns.message.from_wire(wire, question_only=True)
r = dns.message.make_response(q)
r.set_rcode(dns.rcode.FORMERR)
+ items.append(r)
except Exception:
# We could try to make a response from only the header
# if dnspython had a header_only option to
# from_wire(), or if we truncated wire outselves, but
# for now we just drop.
- return None
+ return []
try:
- # r might have been set above, so skip handle() if we
- # already have a response.
- if r is None:
- r = self.handle(q, peer, connection_type)
+ # items might have been appended to above, so skip
+ # handle() if we already have a response.
+ if not items:
+ items = self.maybe_listify(self.handle(q, peer,
+ connection_type))
except Exception:
# Exceptions from handle get a SERVFAIL response.
r = dns.message.make_response(q)
r.set_rcode(dns.rcode.SERVFAIL)
- if isinstance(r, dns.message.Message):
- wire = r.to_wire()
- else:
- wire = r
- return wire
+ items = [r]
+ return [self.maybe_render(x) for x in items]
async def serve_udp(self):
with trio.socket.from_stdlib_socket(self.udp) as sock:
while True:
try:
(wire, peer) = await sock.recvfrom(65535)
- wire = self.handle_wire(wire, peer, ConnectionType.UDP)
- if wire is not None:
+ for wire in self.handle_wire(wire, peer,
+ ConnectionType.UDP):
await sock.sendto(wire, peer)
except Exception:
pass
ldata = await read_exactly(stream, 2)
(l,) = struct.unpack("!H", ldata)
wire = await read_exactly(stream, l)
- wire = self.handle_wire(wire, peer, ConnectionType.TCP)
- if wire is not None:
+ for wire in self.handle_wire(wire, peer, ConnectionType.TCP):
l = len(wire)
stream_message = struct.pack("!H", l) + wire
await stream.send_all(stream_message)