import contextlib
import enum
+import errno
import functools
import socket
import struct
UDP = 1
TCP = 2
+class Request:
+ def __init__(self, message, wire, peer, local, connection_type):
+ self.message = message
+ self.wire = wire
+ self.peer = peer
+ self.local = local
+ self.connection_type = connection_type
+
+ @property
+ def question(self):
+ return self.message.question[0]
+
+ @property
+ def qname(self):
+ return self.question.name
+
+ @property
+ def qclass(self):
+ return self.question.rdclass
+
+ @property
+ def qtype(self):
+ return self.question.rdtype
+
class Server(threading.Thread):
"""The nanoserver is a nameserver skeleton suitable for faking a DNS
self.tcp = None
self.tcp_address = None
- def __enter__(self):
- (self.left, self.right) = socket.socketpair()
- # We're making the UDP socket now so it can be sent to by the
- # caller immediately (i.e. no race with the listener starting
- # in the thread).
+ def _open_sockets(self):
if self.enable_udp:
self.udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
self.udp.bind((self.address, self.port))
if self.enable_tcp:
self.tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
self.tcp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- self.tcp.bind((self.address, self.port))
+ port = self.port
+ if port is 0 and self.enable_udp:
+ port = self.udp_address[1]
+ port = 12347
+ self.tcp.bind((self.address, port))
self.tcp.listen()
self.tcp_address = self.tcp.getsockname()
+
+ def __enter__(self):
+ (self.left, self.right) = socket.socketpair()
+ # We're making the sockets now so they can be sent to by the
+ # caller immediately (i.e. no race with the listener starting
+ # in the thread).
+ open_udp_sockets = []
+ while True:
+ if self.enable_udp:
+ self.udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
+ self.udp.bind((self.address, self.port))
+ self.udp_address = self.udp.getsockname()
+ if self.enable_tcp:
+ self.tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+ self.tcp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ if self.port is 0 and self.enable_udp:
+ try:
+ self.tcp.bind((self.address, self.udp_address[1]))
+ except OSError as e:
+ if e.errno == errno.EADDRINUSE and \
+ len(open_udp_sockets) < 100:
+ open_udp_sockets.append(self.udp)
+ continue
+ raise
+ else:
+ self.tcp.bind((self.address, self.port))
+ self.tcp.listen()
+ self.tcp_address = self.tcp.getsockname()
+ break
if self.use_thread:
self.start()
return self
else:
return [thing]
- def handle_wire(self, wire, peer, connection_type):
+ def handle_wire(self, wire, peer, local, connection_type):
#
# This is the common code to parse wire format, call handle() on
# the message, and then generate response wire format (if handle()
# items might have been appended to above, so skip
# handle() if we already have a response.
if not items:
- items = self.maybe_listify(self.handle(q, peer,
- connection_type))
+ request = Request(q, wire, peer, local, connection_type)
+ items = self.maybe_listify(self.handle(request))
except Exception:
# Exceptions from handle get a SERVFAIL response.
r = dns.message.make_response(q)
async def serve_udp(self):
with trio.socket.from_stdlib_socket(self.udp) as sock:
self.udp = None # we own cleanup
+ local = self.udp_address
while True:
try:
(wire, peer) = await sock.recvfrom(65535)
- for wire in self.handle_wire(wire, peer,
+ for wire in self.handle_wire(wire, peer, local,
ConnectionType.UDP):
await sock.sendto(wire, peer)
except Exception:
async def serve_tcp(self, stream):
try:
peer = stream.socket.getpeername()
+ local = stream.socket.getsockname()
while True:
ldata = await read_exactly(stream, 2)
(l,) = struct.unpack("!H", ldata)
wire = await read_exactly(stream, l)
- for wire in self.handle_wire(wire, peer, ConnectionType.TCP):
+ for wire in self.handle_wire(wire, peer, local,
+ ConnectionType.TCP):
l = len(wire)
stream_message = struct.pack("!H", l) + wire
await stream.send_all(stream_message)
class AXFRNanoNameserver(Server):
- def handle(self, message, peer, connection_type):
+ def handle(self, request):
self.zone = dns.zone.from_text(axfr_zone, origin='example')
self.origin = self.zone.origin
items = []
soa = self.zone.find_rrset(dns.name.empty, dns.rdatatype.SOA)
- response = dns.message.make_response(message)
+ response = dns.message.make_response(request.message)
response.flags |= dns.flags.AA
response.answer.append(soa)
items.append(response)
- response = dns.message.make_response(message)
+ response = dns.message.make_response(request.message)
response.question = []
response.flags |= dns.flags.AA
for (name, rdataset) in self.zone.iterate_rdatasets():
rrset.update(rdataset)
response.answer.append(rrset)
items.append(response)
- response = dns.message.make_response(message)
+ response = dns.message.make_response(request.message)
response.question = []
response.flags |= dns.flags.AA
response.answer.append(soa)
super().__init__()
self.response_text = response_text
- def handle(self, message, peer, connection_type):
+ def handle(self, request):
try:
r = dns.message.from_text(self.response_text, one_rr_per_rrset=True)
- r.id = message.id
+ r.id = request.message.id
return r
except Exception:
pass
class TSIGNanoNameserver(Server):
- def handle(self, message, peer, connection_type):
- response = dns.message.make_response(message)
+ def handle(self, request):
+ response = dns.message.make_response(request.message)
response.set_rcode(dns.rcode.REFUSED)
response.flags |= dns.flags.RA
try:
- if message.question[0].rdtype == dns.rdatatype.A and \
- message.question[0].rdclass == dns.rdataclass.IN:
- rrs = dns.rrset.from_text(message.question[0].name, 300,
+ if request.qtype == dns.rdatatype.A and \
+ request.qclass == dns.rdataclass.IN:
+ rrs = dns.rrset.from_text(request.qname, 300,
'IN', 'A', '1.2.3.4')
response.answer.append(rrs)
response.set_rcode(dns.rcode.NOERROR)
class NaptrNanoNameserver(Server):
- def handle(self, message, peer, connection_type):
- response = dns.message.make_response(message)
+ def handle(self, request):
+ response = dns.message.make_response(request.message)
response.set_rcode(dns.rcode.REFUSED)
response.flags |= dns.flags.RA
try:
zero_subdomain = dns.e164.from_e164('0')
- if message.question[0].name.is_subdomain(zero_subdomain):
+ if request.qname.is_subdomain(zero_subdomain):
response.set_rcode(dns.rcode.NXDOMAIN)
response.flags |= dns.flags.AA
- elif message.question[0].rdtype == dns.rdatatype.NAPTR and \
- message.question[0].rdclass == dns.rdataclass.IN:
- rrs = dns.rrset.from_text(message.question[0].name, 300,
- 'IN', 'NAPTR',
+ elif request.qtype == dns.rdatatype.NAPTR and \
+ request.qclass == dns.rdataclass.IN:
+ rrs = dns.rrset.from_text(request.qname, 300, 'IN', 'NAPTR',
'0 0 "" "" "" .')
response.answer.append(rrs)
response.set_rcode(dns.rcode.NOERROR)