import sys
import threading
import time
+import clientsubnetoption
# Python2/3 compatibility hacks
try:
def getFirstProtobufMessage(self, retries=1, waitTime=1):
msg = None
- print("in getFirstProtobufMessage")
+ #print("in getFirstProtobufMessage")
for param in protobufServersParameters:
- print(param.port)
+ #print(param.port)
failed = 0
- while param.queue.empty:
- print(failed)
- print(retries)
+ while param.queue.empty():
+ #print(failed)
+ #print(retries)
if failed >= retries:
break
failed = failed + 1
- print("waiting")
+ #print("waiting")
time.sleep(waitTime)
self.assertFalse(param.queue.empty())
if oldmsg is not None:
self.assertEqual(msg, oldmsg)
- print(msg)
+ #print(msg)
return msg
+ def emptyProtoBufQueue(self):
+ for param in protobufServersParameters:
+ while not param.queue.empty():
+ param.queue.get(False)
+
def checkNoRemainingMessage(self):
for param in protobufServersParameters:
self.assertTrue(param.queue.empty())
self.assertEqual(len(msg.originalRequestorSubnet), 4)
self.assertEqual(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1')
- def checkOutgoingProtobufBase(self, msg, protocol, query, initiator, length=None):
+ def checkOutgoingProtobufBase(self, msg, protocol, query, initiator, length=None, expectedECS=None):
self.assertTrue(msg)
self.assertTrue(msg.HasField('timeSec'))
self.assertTrue(msg.HasField('socketFamily'))
else:
# compare inBytes with length of query/response
self.assertEqual(msg.inBytes, len(query.to_wire()))
+ if expectedECS is not None:
+ self.assertTrue(msg.HasField('originalRequestorSubnet'))
+ # v4 only for now
+ self.assertEqual(len(msg.originalRequestorSubnet), 4)
+ self.assertEqual(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), expectedECS)
def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1', to='127.0.0.1'):
self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)
self.assertEqual(msg.response.appliedPolicyKind, kind)
def checkProtobufTags(self, msg, tags):
- print(tags)
- print('---')
- print(msg.response.tags)
+ #print(tags)
+ #print('---')
+ #print(msg.response.tags)
self.assertEqual(len(msg.response.tags), len(tags))
for tag in msg.response.tags:
self.assertTrue(tag in tags)
def checkProtobufMetas(self, msg, metas):
- print(metas)
- print('---')
- print(msg.meta)
+ #print(metas)
+ #print('---')
+ #print(msg.meta)
self.assertEqual(len(msg.meta), len(metas))
for m in msg.meta:
self.assertTrue(m.HasField('key'))
for s in m.value.stringVal :
self.assertTrue(s in metas[m.key]['stringVal'])
- def checkProtobufOutgoingQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1', length=None):
+ def checkProtobufOutgoingQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1', length=None, expectedECS=None):
self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSOutgoingQueryType)
- self.checkOutgoingProtobufBase(msg, protocol, query, initiator, length=length)
+ self.checkOutgoingProtobufBase(msg, protocol, query, initiator, length=length, expectedECS=expectedECS)
self.assertTrue(msg.HasField('to'))
self.assertTrue(msg.HasField('question'))
self.assertTrue(msg.question.HasField('qClass'))
self.assertEqual(msg.response.rcode, 65536)
def checkProtobufIdentity(self, msg, requestorId, deviceId, deviceName):
- print(msg)
+ #print(msg)
self.assertTrue((requestorId == '') == (not msg.HasField('requestorId')))
self.assertTrue((deviceId == b'') == (not msg.HasField('deviceId')))
self.assertTrue((deviceName == '') == (not msg.HasField('deviceName')))
super(TestRecursorProtobuf, self).setUp()
# Make sure the queue is empty, in case
# a previous test failed
- for param in protobufServersParameters:
- while not param.queue.empty():
- param.queue.get(False)
+ self.emptyProtoBufQueue()
# wait long enough to be sure that the housekeeping has
# prime the root NS
time.sleep(1)
# Switch off QName Minimization, it generates much more protobuf messages
# (or make the test much more smart!)
qname-minimization=no
+ max-cache-ttl=600
+ loglevel=9
"""
_lua_config_file = """
outgoingProtobufServer({"127.0.0.1:%d", "127.0.0.1:%d"})
""" % (protobufServersParameters[0].port, protobufServersParameters[1].port)
def testA(self):
+ # There is a race in priming (having the . DNSKEY in cache in particular) and this code.
+ # So make sure we have the . DNSKEY in cache
+ query = dns.message.make_query('.', 'A', want_dnssec=True)
+ query.flags |= dns.flags.RD
+ res = self.sendUDPQuery(query)
+ time.sleep(1)
+ self.emptyProtoBufQueue()
+
name = 'host1.secure.example.'
expected = list()
- # the root DNSKEY has been learned with priming the root NS already
- # ('.', dns.rdatatype.DNSKEY, dnsmessage_pb2.PBDNSMessage.UDP, 201),
for qname, qtype, proto, responseSize in [
('host1.secure.example.', dns.rdatatype.A, dnsmessage_pb2.PBDNSMessage.UDP, 248),
('host1.secure.example.', dns.rdatatype.A, dnsmessage_pb2.PBDNSMessage.UDP, 221),
self.checkNoRemainingMessage()
+class OutgoingProtobufWithECSMappingTest(TestRecursorProtobuf):
+ """
+ This test makes sure that we correctly export outgoing queries over protobuf.
+ It must be improved and setup env so we can check for incoming responses, but makes sure for now
+ that the recursor at least connects to the protobuf server.
+ """
+
+ _confdir = 'OutgoingProtobuffWithECSMapping'
+ _config_template = """
+ # Switch off QName Minimization, it generates much more protobuf messages
+ # (or make the test much more smart!)
+ qname-minimization=no
+ edns-subnet-allow-list=example
+ allow-from=1.2.3.4/32
+ # this is to not let . queries interfere
+ max-cache-ttl=600
+ loglevel=9
+"""
+ _lua_config_file = """
+ outgoingProtobufServer({"127.0.0.1:%d", "127.0.0.1:%d"})
+ addProxyMapping("127.0.0.0/8", "1.2.3.4", { "host1.secure.example." })
+ """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
+
+ def testA(self):
+ # There is a race in priming (having the . DNSKEY in cache in particular) and this code.
+ # So make sure we have the . DNSKEY in cache
+ query = dns.message.make_query('.', 'A', want_dnssec=True)
+ query.flags |= dns.flags.RD
+ res = self.sendUDPQuery(query)
+ time.sleep(1)
+ self.emptyProtoBufQueue()
+
+ name = 'host1.secure.example.'
+ expected = list()
+
+ for qname, qtype, proto, responseSize, ecs in [
+ ('host1.secure.example.', dns.rdatatype.A, dnsmessage_pb2.PBDNSMessage.UDP, 248, "1.2.3.0"),
+ ('host1.secure.example.', dns.rdatatype.A, dnsmessage_pb2.PBDNSMessage.UDP, 221, "1.2.3.0"),
+ ('example.', dns.rdatatype.DNSKEY, dnsmessage_pb2.PBDNSMessage.UDP, 219, "1.2.3.0"),
+ ('host1.secure.example.', dns.rdatatype.A, dnsmessage_pb2.PBDNSMessage.UDP, 175, "1.2.3.0"),
+ ('secure.example.', dns.rdatatype.DNSKEY, dnsmessage_pb2.PBDNSMessage.UDP, 233, "1.2.3.0"),
+ ]:
+ if not qname:
+ expected.append((None, None, None, None, None, None, None))
+ continue
+ ecso = clientsubnetoption.ClientSubnetOption('9.10.11.12', 24)
+ query = dns.message.make_query(qname, qtype, use_edns=True, want_dnssec=True, options=[ecso], payload=512)
+ resp = dns.message.make_response(query)
+ expected.append((
+ qname, qtype, query, resp, proto, responseSize, ecs
+ ))
+
+ query = dns.message.make_query(name, 'A', want_dnssec=True)
+ query.flags |= dns.flags.RD
+ res = self.sendUDPQuery(query)
+
+ for qname, qtype, qry, ans, proto, responseSize, ecs in expected:
+ if not qname:
+ self.getFirstProtobufMessage()
+ self.getFirstProtobufMessage()
+ continue
+
+ msg = self.getFirstProtobufMessage()
+ self.checkProtobufOutgoingQuery(msg, proto, qry, dns.rdataclass.IN, qtype, qname, "127.0.0.1", None, ecs)
+ # Check the answer
+ msg = self.getFirstProtobufMessage()
+ self.checkProtobufIncomingResponse(msg, proto, ans, length=responseSize)
+
+ self.checkNoRemainingMessage()
+
+ # this query should use the unmapped ECS
+ name = 'mx1.secure.example.'
+ expected = list()
+
+ for qname, qtype, proto, responseSize, ecs in [
+ ('mx1.secure.example.', dns.rdatatype.A, dnsmessage_pb2.PBDNSMessage.UDP, 173, "127.0.0.1"),
+ ]:
+ if not qname:
+ expected.append((None, None, None, None, None, None, None))
+ continue
+ ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 32)
+ query = dns.message.make_query(qname, qtype, use_edns=True, want_dnssec=True, options=[ecso], payload=512)
+ resp = dns.message.make_response(query)
+ expected.append((
+ qname, qtype, query, resp, proto, responseSize, ecs
+ ))
+
+ query = dns.message.make_query(name, 'A', want_dnssec=True)
+ query.flags |= dns.flags.RD
+ res = self.sendUDPQuery(query)
+
+ for qname, qtype, qry, ans, proto, responseSize, ecs in expected:
+ if not qname:
+ self.getFirstProtobufMessage()
+ self.getFirstProtobufMessage()
+ continue
+
+ msg = self.getFirstProtobufMessage()
+ self.checkProtobufOutgoingQuery(msg, proto, qry, dns.rdataclass.IN, qtype, qname, "127.0.0.1", None, ecs)
+ # Check the answer
+ msg = self.getFirstProtobufMessage()
+ self.checkProtobufIncomingResponse(msg, proto, ans, length=responseSize)
+
+ self.checkNoRemainingMessage()
+
class OutgoingProtobufNoQueriesTest(TestRecursorProtobuf):
"""
This test makes sure that we correctly export incoming responses but not outgoing queries over protobuf.
_config_template = """
# Switch off QName Minimization, it generates much more protobuf messages
# (or make the test much more smart!)
- qname-minimization=no"""
+ qname-minimization=no
+ max-cache-ttl=600
+ loglevel=9
+"""
_lua_config_file = """
outgoingProtobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { logQueries=false, logResponses=true })
""" % (protobufServersParameters[0].port, protobufServersParameters[1].port)
def testA(self):
+ # There is a race in priming (having the . DNSKEY in cache in particular) and this code.
+ # So make sure we have the . DNSKEY in cache
+ query = dns.message.make_query('.', 'A', want_dnssec=True)
+ query.flags |= dns.flags.RD
+ res = self.sendUDPQuery(query)
+ time.sleep(1)
+ self.emptyProtoBufQueue()
+
name = 'host1.secure.example.'
expected = list()
# the root DNSKEY has been learned with priming the root NS already
self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
tags = [ self._tag_from_gettag ] + self._tags
+ #print(msg)
self.checkProtobufTags(msg, tags)
self.checkNoRemainingMessage()
+ # Again to check PC case
+ res = self.sendUDPQuery(query)
+ self.assertRRsetInAnswer(res, expected)
+
+ # check the protobuf messages corresponding to the UDP query and answer
+ msg = self.getFirstProtobufMessage()
+ self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+ self.checkProtobufTags(msg, [ self._tag_from_gettag ])
+ # then the response
+ msg = self.getFirstProtobufMessage()
+ self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+ self.assertEqual(len(msg.response.rrs), 1)
+ rr = msg.response.rrs[0]
+ # time may have passed, so do not check TTL
+ self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15, checkTTL=False)
+ self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
+ tags = [ self._tag_from_gettag ] + self._tags
+ self.checkProtobufTags(msg, tags)
+ self.checkNoRemainingMessage()
+
+class ProtobufTagCacheTest(TestRecursorProtobuf):
+ """
+ This test makes sure that we correctly cache tags (actually not cache them)
+ """
+
+ _confdir = 'ProtobufTagCache'
+ _config_template = """
+auth-zones=example=configs/%s/example.zone""" % _confdir
+ _lua_config_file = """
+ protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { logQueries=false, logResponses=true } )
+ """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
+ _lua_dns_script_file = """
+ function gettag(remote, ednssubnet, localip, qname, qtype, ednsoptions, tcp)
+ if qname:equal('tagged.example.') then
+ return 0, { '' .. math.random() }
+ end
+ return 0
+ end
+ """
+
+ def testTagged(self):
+ name = 'tagged.example.'
+ expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.84')
+ query = dns.message.make_query(name, 'A', want_dnssec=True)
+ query.flags |= dns.flags.CD
+ res = self.sendUDPQuery(query)
+ self.assertRRsetInAnswer(res, expected)
+
+ msg = self.getFirstProtobufMessage()
+ self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+ self.assertEqual(len(msg.response.rrs), 1)
+ rr = msg.response.rrs[0]
+ # we have max-cache-ttl set to 15
+ self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
+ self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
+ self.checkNoRemainingMessage()
+ self.assertEqual(len(msg.response.tags), 1)
+ ts1 = msg.response.tags[0]
+
+ # Again to check PC case
+ res = self.sendUDPQuery(query)
+ self.assertRRsetInAnswer(res, expected)
+
+ msg = self.getFirstProtobufMessage()
+ self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+ self.assertEqual(len(msg.response.rrs), 1)
+ rr = msg.response.rrs[0]
+ # time may have passed, so do not check TTL
+ self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15, checkTTL=False)
+ self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
+ self.checkNoRemainingMessage()
+ self.assertEqual(len(msg.response.tags), 1)
+ ts2 = msg.response.tags[0]
+ self.assertNotEqual(ts1, ts2)
+
class ProtobufSelectedFromLuaTest(TestRecursorProtobuf):
"""
This test makes sure that we correctly export queries and responses but only if they have been selected from Lua.