import time
# Python2/3 compatibility hacks
-if sys.version_info[0] == 2:
+try:
+ from queue import Queue
+except ImportError:
from Queue import Queue
+
+try:
range = xrange
-else:
- from queue import Queue
- range = range # allow re-export of the builtin name
+except NameError:
+ pass
from recursortests import RecursorTest
-protobufQueue = Queue()
-protobufServerPort = 4243
-
def ProtobufConnectionHandler(queue, conn):
data = None
while True:
conn.close()
-def ProtobufListener(port):
- global protobufQueue
+def ProtobufListener(queue, port):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
try:
(conn, _) = sock.accept()
thread = threading.Thread(name='Connection Handler',
target=ProtobufConnectionHandler,
- args=[protobufQueue, conn])
+ args=[queue, conn])
thread.setDaemon(True)
thread.start()
sock.close()
-protobufListener = threading.Thread(name='Protobuf Listener', target=ProtobufListener, args=[protobufServerPort])
-protobufListener.setDaemon(True)
-protobufListener.start()
+class ProtobufServerParams:
+ def __init__(self, port):
+ self.queue = Queue()
+ self.port = port
+
+protobufServersParameters = [ProtobufServerParams(4243), ProtobufServerParams(4244)]
+protobufListeners = []
+for param in protobufServersParameters:
+ listener = threading.Thread(name='Protobuf Listener', target=ProtobufListener, args=[param.queue, param.port])
+ listener.setDaemon(True)
+ listener.start()
+ protobufListeners.append(listener)
class TestRecursorProtobuf(RecursorTest):
- global protobufServerPort
_lua_config_file = """
- protobufServer("127.0.0.1:%d")
- """ % (protobufServerPort)
+ protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"})
+ """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
def getFirstProtobufMessage(self, retries=1, waitTime=1):
- global protobufQueue
- failed = 0
-
- while protobufQueue.empty:
- if failed >= retries:
- break
-
- failed = failed + 1
- time.sleep(waitTime)
+ msg = None
+
+ print("in getFirstProtobufMessage")
+ for param in protobufServersParameters:
+ print(param.port)
+ failed = 0
+
+ while param.queue.empty:
+ print(failed)
+ print(retries)
+ if failed >= retries:
+ break
+
+ failed = failed + 1
+ print("waiting")
+ time.sleep(waitTime)
+
+ self.assertFalse(param.queue.empty())
+ data = param.queue.get(False)
+ self.assertTrue(data)
+ oldmsg = msg
+ msg = dnsmessage_pb2.PBDNSMessage()
+ msg.ParseFromString(data)
+ if oldmsg is not None:
+ self.assertEquals(msg, oldmsg)
- self.assertFalse(protobufQueue.empty())
- data = protobufQueue.get(False)
- self.assertTrue(data)
- msg = dnsmessage_pb2.PBDNSMessage()
- msg.ParseFromString(data)
return msg
def checkNoRemainingMessage(self):
- global protobufQueue
- self.assertTrue(protobufQueue.empty())
+ for param in protobufServersParameters:
+ self.assertTrue(param.queue.empty())
- def checkProtobufBase(self, msg, protocol, query, initiator, normalQueryResponse=True, expectedECS=None):
+ def checkProtobufBase(self, msg, protocol, query, initiator, normalQueryResponse=True, expectedECS=None, receivedSize=None):
self.assertTrue(msg)
self.assertTrue(msg.HasField('timeSec'))
self.assertTrue(msg.HasField('socketFamily'))
self.assertTrue(msg.HasField('socketProtocol'))
self.assertEquals(msg.socketProtocol, protocol)
self.assertTrue(msg.HasField('messageId'))
+ self.assertTrue(msg.HasField('serverIdentity'))
self.assertTrue(msg.HasField('id'))
self.assertEquals(msg.id, query.id)
self.assertTrue(msg.HasField('inBytes'))
if normalQueryResponse:
# compare inBytes with length of query/response
- self.assertEquals(msg.inBytes, len(query.to_wire()))
+ # Note that for responses, the size we received might differ
+ # because dnspython might compress labels differently from
+ # the recursor
+ if receivedSize:
+ self.assertEquals(msg.inBytes, receivedSize)
+ else:
+ self.assertEquals(msg.inBytes, len(query.to_wire()))
if expectedECS is not None:
self.assertTrue(msg.HasField('originalRequestorSubnet'))
# v4 only for now
self.assertTrue(msg.HasField('socketProtocol'))
self.assertEquals(msg.socketProtocol, protocol)
self.assertTrue(msg.HasField('messageId'))
+ self.assertTrue(msg.HasField('serverIdentity'))
self.assertTrue(msg.HasField('id'))
self.assertNotEquals(msg.id, query.id)
self.assertTrue(msg.HasField('inBytes'))
self.assertTrue(msg.question.HasField('qName'))
self.assertEquals(msg.question.qName, qname)
- def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1'):
+ def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1', receivedSize=None):
self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
- self.checkProtobufBase(msg, protocol, response, initiator)
+ self.checkProtobufBase(msg, protocol, response, initiator, receivedSize=receivedSize)
self.assertTrue(msg.HasField('response'))
self.assertTrue(msg.response.HasField('queryTimeSec'))
- def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl):
+ def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl, checkTTL=True):
self.assertTrue(record.HasField('class'))
self.assertEquals(getattr(record, 'class'), rclass)
self.assertTrue(record.HasField('type'))
self.assertTrue(record.HasField('name'))
self.assertEquals(record.name, rname)
self.assertTrue(record.HasField('ttl'))
- self.assertEquals(record.ttl, rttl)
+ if checkTTL:
+ self.assertEquals(record.ttl, rttl)
self.assertTrue(record.HasField('rdata'))
def checkProtobufPolicy(self, msg, policyType, reason):
@classmethod
def setUpClass(cls):
- global protobufListener
- global protobufServerPort
- global ProtobufListener
- if protobufListener is None or not protobufListener.isAlive():
- protobufListener = threading.Thread(name='Protobuf Listener', target=ProtobufListener, args=[protobufServerPort])
- protobufListener.setDaemon(True)
- protobufListener.start()
-
cls.setUpSockets()
cls.startResponders()
def setUp(self):
# Make sure the queue is empty, in case
# a previous test failed
- global protobufQueue
- while not protobufQueue.empty():
- protobufQueue.get(False)
+ for param in protobufServersParameters:
+ while not param.queue.empty():
+ param.queue.get(False)
@classmethod
def generateRecursorConfig(cls, confdir):
tagged 3600 IN A 192.0.2.84
query-selected 3600 IN A 192.0.2.84
answer-selected 3600 IN A 192.0.2.84
+types 3600 IN A 192.0.2.84
+types 3600 IN AAAA 2001:DB8::1
+types 3600 IN TXT "Lorem ipsum dolor sit amet"
+types 3600 IN MX 10 a.example.
+types 3600 IN SPF "v=spf1 -all"
+types 3600 IN SRV 10 20 443 a.example.
+cname 3600 IN CNAME a.example.
+
""".format(soa=cls._SOA))
super(TestRecursorProtobuf, cls).generateRecursorConfig(confdir)
query = dns.message.make_query(name, 'A', want_dnssec=True)
query.flags |= dns.flags.CD
res = self.sendUDPQuery(query)
+
self.assertRRsetInAnswer(res, expected)
# check the protobuf messages corresponding to the UDP query and answer
self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
# then the response
msg = self.getFirstProtobufMessage()
- self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+ self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res, '127.0.0.1')
self.assertEquals(len(msg.response.rrs), 1)
rr = msg.response.rrs[0]
# we have max-cache-ttl set to 15
self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.42')
self.checkNoRemainingMessage()
+ def testCNAME(self):
+ name = 'cname.example.'
+ expectedCNAME = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'CNAME', 'a.example.')
+ expectedA = dns.rrset.from_text('a.example.', 0, dns.rdataclass.IN, 'A', '192.0.2.42')
+ query = dns.message.make_query(name, 'A', want_dnssec=True)
+ query.flags |= dns.flags.CD
+ raw = self.sendUDPQuery(query, decode=False)
+ res = dns.message.from_wire(raw)
+ self.assertRRsetInAnswer(res, expectedCNAME)
+ self.assertRRsetInAnswer(res, expectedA)
+
+ # check the protobuf messages corresponding to the UDP query and answer
+ # but first let the protobuf messages the time to get there
+ msg = self.getFirstProtobufMessage()
+ self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+ # then the response
+ msg = self.getFirstProtobufMessage()
+ self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res, '127.0.0.1', receivedSize=len(raw))
+ self.assertEquals(len(msg.response.rrs), 2)
+ rr = msg.response.rrs[0]
+ # we don't want to check the TTL for the A record, it has been cached by the previous test
+ self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 15)
+ self.assertEquals(rr.rdata, 'a.example.')
+ rr = msg.response.rrs[1]
+ # we have max-cache-ttl set to 15
+ self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, 'a.example.', 15, checkTTL=False)
+ self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.42')
+ self.checkNoRemainingMessage()
+
class OutgoingProtobufDefaultTest(TestRecursorProtobuf):
"""
This test makes sure that we correctly export outgoing queries over protobuf.
_config_template = """
auth-zones=example=configs/%s/example.zone""" % _confdir
_lua_config_file = """
- outgoingProtobufServer("127.0.0.1:%d")
- """ % (protobufServerPort)
+ outgoingProtobufServer({"127.0.0.1:%d", "127.0.0.1:%d"})
+ """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
def testA(self):
name = 'www.example.org.'
_confdir = 'ProtobufMasks'
_config_template = """
auth-zones=example=configs/%s/example.zone""" % _confdir
- global protobufServerPort
_protobufMaskV4 = 4
_protobufMaskV6 = 128
_lua_config_file = """
- protobufServer("127.0.0.1:%d")
+ protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"})
setProtobufMasks(%d, %d)
- """ % (protobufServerPort, _protobufMaskV4, _protobufMaskV6)
+ """ % (protobufServersParameters[0].port, protobufServersParameters[1].port, _protobufMaskV4, _protobufMaskV6)
def testA(self):
name = 'a.example.'
_confdir = 'ProtobufQueriesOnly'
_config_template = """
auth-zones=example=configs/%s/example.zone""" % _confdir
- global protobufServerPort
_lua_config_file = """
- protobufServer("127.0.0.1:%d", { logQueries=true, logResponses=false } )
- """ % (protobufServerPort)
+ protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { logQueries=true, logResponses=false } )
+ """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
def testA(self):
name = 'a.example.'
_confdir = 'ProtobufResponsesOnly'
_config_template = """
auth-zones=example=configs/%s/example.zone""" % _confdir
- global protobufServerPort
_lua_config_file = """
- protobufServer("127.0.0.1:%d", { logQueries=false, logResponses=true } )
- """ % (protobufServerPort)
+ protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { logQueries=false, logResponses=true } )
+ """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
def testA(self):
name = 'a.example.'
_confdir = 'ProtobufTaggedOnly'
_config_template = """
auth-zones=example=configs/%s/example.zone""" % _confdir
- global protobufServerPort
_lua_config_file = """
- protobufServer("127.0.0.1:%d", { logQueries=true, logResponses=true, taggedOnly=true } )
- """ % (protobufServerPort)
+ protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { logQueries=true, logResponses=true, taggedOnly=true } )
+ """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
_tags = ['tag1', 'tag2']
_tag_from_gettag = 'tag-from-gettag'
_lua_dns_script_file = """
_confdir = 'ProtobufSelectedFromLua'
_config_template = """
auth-zones=example=configs/%s/example.zone""" % _confdir
- global protobufServerPort
_lua_config_file = """
- protobufServer("127.0.0.1:%d", { logQueries=false, logResponses=false } )
- """ % (protobufServerPort)
+ protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { logQueries=false, logResponses=false } )
+ """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
_lua_dns_script_file = """
local ffi = require("ffi")
self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
self.checkNoRemainingMessage()
+
+class ProtobufExportTypesTest(TestRecursorProtobuf):
+ """
+ This test makes sure that we correctly export other types than A, AAAA and CNAME over protobuf.
+ """
+
+ _confdir = 'ProtobufExportTypes'
+ _config_template = """
+auth-zones=example=configs/%s/example.zone""" % _confdir
+ _lua_config_file = """
+ protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { exportTypes={"AAAA", "MX", "SPF", "SRV", "TXT"} } )
+ """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
+
+ def testA(self):
+ name = 'types.example.'
+ expected = [dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.84'),
+ dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'AAAA', '2001:DB8::1'),
+ dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'MX', '10 a.example.'),
+ dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'SPF', '"v=spf1 -all"'),
+ dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'SRV', '10 20 443 a.example.'),
+ dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'TXT', '"Lorem ipsum dolor sit amet"'),
+ ]
+ query = dns.message.make_query(name, 'ANY', want_dnssec=True)
+ query.flags |= dns.flags.CD
+ raw = self.sendUDPQuery(query, decode=False)
+ res = dns.message.from_wire(raw)
+
+ for rrset in expected:
+ self.assertRRsetInAnswer(res, rrset)
+
+ # check the protobuf messages corresponding to the UDP query and answer
+ msg = self.getFirstProtobufMessage()
+ self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+ # then the response
+ msg = self.getFirstProtobufMessage()
+ self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res, '127.0.0.1', receivedSize=len(raw))
+ self.assertEquals(len(msg.response.rrs), 5)
+ for rr in msg.response.rrs:
+ self.assertTrue(rr.type in [dns.rdatatype.AAAA, dns.rdatatype.TXT, dns.rdatatype.MX, dns.rdatatype.SPF, dns.rdatatype.SRV])
+
+ if rr.type == dns.rdatatype.AAAA:
+ self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.AAAA, name, 15)
+ self.assertEquals(socket.inet_ntop(socket.AF_INET6, rr.rdata), '2001:db8::1')
+ elif rr.type == dns.rdatatype.TXT:
+ self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.TXT, name, 15)
+ self.assertEquals(rr.rdata, '"Lorem ipsum dolor sit amet"')
+ elif rr.type == dns.rdatatype.MX:
+ self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.MX, name, 15)
+ self.assertEquals(rr.rdata, 'a.example.')
+ elif rr.type == dns.rdatatype.SPF:
+ self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.SPF, name, 15)
+ self.assertEquals(rr.rdata, '"v=spf1 -all"')
+ elif rr.type == dns.rdatatype.SRV:
+ self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.SRV, name, 15)
+ self.assertEquals(rr.rdata, 'a.example.')
+
+ self.checkNoRemainingMessage()