cls._responsesCounter[threading.currentThread().name] = 1
@classmethod
- def _getResponse(cls, request, fromQueue, toQueue):
+ def _getResponse(cls, request, fromQueue, toQueue, synthesize=None):
response = None
if len(request.question) != 1:
print("Skipping query with question count %d" % (len(request.question)))
healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
if healthCheck:
cls._healthCheckCounter += 1
+ response = dns.message.make_response(request)
else:
cls._ResponderIncrementCounter()
if not fromQueue.empty():
- response = fromQueue.get(True, cls._queueTimeout)
- if response:
- response = copy.copy(response)
- response.id = request.id
- toQueue.put(request, True, cls._queueTimeout)
+ toQueue.put(request, True, cls._queueTimeout)
+ if synthesize is None:
+ response = fromQueue.get(True, cls._queueTimeout)
+ if response:
+ response = copy.copy(response)
+ response.id = request.id
if not response:
- if healthCheck:
+ if synthesize is not None:
response = dns.message.make_response(request)
+ response.set_rcode(synthesize)
elif cls._answerUnexpected:
response = dns.message.make_response(request)
response.set_rcode(dns.rcode.SERVFAIL)
return response
@classmethod
- def UDPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False):
+ def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False):
+ ignoreTrailing = trailingDataResponse is True
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
sock.bind(("127.0.0.1", port))
while True:
data, addr = sock.recvfrom(4096)
- request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
- response = cls._getResponse(request, fromQueue, toQueue)
-
+ forceRcode = None
+ try:
+ request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
+ except dns.message.TrailingJunk as e:
+ if trailingDataResponse is False:
+ raise
+ print("UDP query with trailing data, synthesizing response")
+ request = dns.message.from_wire(data, ignore_trailing=True)
+ forceRcode = trailingDataResponse
+
+ response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
if not response:
continue
sock.close()
@classmethod
- def TCPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False, multipleResponses=False):
+ def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False):
+ ignoreTrailing = trailingDataResponse is True
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
try:
(datalen,) = struct.unpack("!H", data)
data = conn.recv(datalen)
- request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
- response = cls._getResponse(request, fromQueue, toQueue)
-
+ forceRcode = None
+ try:
+ request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
+ except dns.message.TrailingJunk as e:
+ if trailingDataResponse is False:
+ raise
+ print("TCP query with trailing data, synthesizing response")
+ request = dns.message.from_wire(data, ignore_trailing=True)
+ forceRcode = trailingDataResponse
+
+ response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
if not response:
conn.close()
continue
import dns
from dnsdisttests import DNSDistTest
-class TestTrailing(DNSDistTest):
+class TestTrailingDataToBackend(DNSDistTest):
# this test suite uses a different responder port
# because, contrary to the other ones, its
_testServerPort = 5360
_config_template = """
newServer{address="127.0.0.1:%s"}
- addAction(AndRule({QTypeRule(dnsdist.AAAA), TrailingDataRule()}), DropAction())
"""
@classmethod
def startResponders(cls):
print("Launching responders..")
- cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, True])
+ # Respond SERVFAIL to queries with trailing data.
+ cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, dns.rcode.SERVFAIL])
cls._UDPResponder.setDaemon(True)
cls._UDPResponder.start()
- cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, True])
+ # Respond SERVFAIL to queries with trailing data.
+ cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, dns.rcode.SERVFAIL])
cls._TCPResponder.setDaemon(True)
cls._TCPResponder.start()
- def testTrailingAllowed(self):
+ def testTrailingPassthrough(self):
"""
- Trailing: Allowed
+ Trailing data: Pass through
"""
- name = 'allowed.trailing.tests.powerdns.com.'
+ name = 'passthrough.trailing.tests.powerdns.com.'
query = dns.message.make_query(name, 'A', 'IN')
response = dns.message.make_response(query)
rrset = dns.rrset.from_text(name,
dns.rdatatype.A,
'127.0.0.1')
response.answer.append(rrset)
+ expectedResponse = dns.message.make_response(query)
+ expectedResponse.set_rcode(dns.rcode.SERVFAIL)
raw = query.to_wire()
raw = raw + b'A'* 20
- (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
- self.assertTrue(receivedQuery)
- self.assertTrue(receivedResponse)
- receivedQuery.id = query.id
- self.assertEquals(query, receivedQuery)
- self.assertEquals(response, receivedResponse)
- (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
- self.assertTrue(receivedQuery)
- self.assertTrue(receivedResponse)
- receivedQuery.id = query.id
- self.assertEquals(query, receivedQuery)
- self.assertEquals(response, receivedResponse)
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
+ # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
+ (receivedQuery, receivedResponse) = sender(raw, response, rawQuery=True)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(receivedQuery, query)
+ self.assertEquals(receivedResponse, expectedResponse)
+
+class TestTrailingDataToDnsdist(DNSDistTest):
+ _config_template = """
+ newServer{address="127.0.0.1:%s"}
+ addAction(AndRule({QNameRule("dropped.trailing.tests.powerdns.com."), TrailingDataRule()}), DropAction())
+ """
def testTrailingDropped(self):
"""
- Trailing: dropped
+ Trailing data: Drop query
"""
name = 'dropped.trailing.tests.powerdns.com.'
- query = dns.message.make_query(name, 'AAAA', 'IN')
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
raw = query.to_wire()
raw = raw + b'A'* 20
- (_, receivedResponse) = self.sendUDPQuery(raw, response=None, rawQuery=True)
- self.assertEquals(receivedResponse, None)
- (_, receivedResponse) = self.sendTCPQuery(raw, response=None, rawQuery=True)
- self.assertEquals(receivedResponse, None)
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+
+ # Verify that queries with no trailing data make it through.
+ # (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+ # (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+ (receivedQuery, receivedResponse) = sender(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(response, receivedResponse)
+
+ # Verify that queries with trailing data don't make it through.
+ # (_, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
+ # (_, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
+ (_, receivedResponse) = sender(raw, response, rawQuery=True)
+ self.assertEquals(receivedResponse, None)