class TestRecordsCountOnlyOneAR(DNSDistTest):
_config_template = """
- addAction(NotRule(RecordsCountRule(DNSSection.Additional, 1, 1)), RCodeAction(dnsdist.REFUSED))
+ addAction(NotRule(RecordsCountRule(DNSSection.Additional, 1, 1)), RCodeAction(DNSRCode.REFUSED))
newServer{address="127.0.0.1:%s"}
"""
expectedResponse = dns.message.make_response(query)
expectedResponse.set_rcode(dns.rcode.REFUSED)
- (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
- self.assertEquals(receivedResponse, expectedResponse)
-
- (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
- self.assertEquals(receivedResponse, expectedResponse)
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (_, receivedResponse) = sender(query, response=None, useQueue=False)
+ self.assertEquals(receivedResponse, expectedResponse)
def testRecordsCountAllowOneAR(self):
"""
dns.rdatatype.A,
'127.0.0.1'))
- (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
- self.assertTrue(receivedQuery)
- self.assertTrue(receivedResponse)
- receivedQuery.id = query.id
- self.assertEquals(query, receivedQuery)
- self.assertEquals(response, receivedResponse)
-
- (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
- self.assertTrue(receivedQuery)
- self.assertTrue(receivedResponse)
- receivedQuery.id = query.id
- self.assertEquals(query, receivedQuery)
- self.assertEquals(response, receivedResponse)
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(response, receivedResponse)
def testRecordsCountRefuseTwoAR(self):
"""
expectedResponse = dns.message.make_response(query)
expectedResponse.set_rcode(dns.rcode.REFUSED)
- (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
- self.assertEquals(receivedResponse, expectedResponse)
-
- (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
- self.assertEquals(receivedResponse, expectedResponse)
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (_, receivedResponse) = sender(query, response=None, useQueue=False)
+ self.assertEquals(receivedResponse, expectedResponse)
class TestRecordsCountMoreThanOneLessThanFour(DNSDistTest):
_config_template = """
addAction(RecordsCountRule(DNSSection.Answer, 2, 3), AllowAction())
- addAction(AllRule(), RCodeAction(dnsdist.REFUSED))
+ addAction(AllRule(), RCodeAction(DNSRCode.REFUSED))
newServer{address="127.0.0.1:%s"}
"""
expectedResponse = dns.message.make_response(query)
expectedResponse.set_rcode(dns.rcode.REFUSED)
- (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
- self.assertEquals(receivedResponse, expectedResponse)
-
- (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
- self.assertEquals(receivedResponse, expectedResponse)
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (_, receivedResponse) = sender(query, response=None, useQueue=False)
+ self.assertEquals(receivedResponse, expectedResponse)
def testRecordsCountAllowTwoAN(self):
"""
response = dns.message.make_response(query)
response.answer.append(rrset)
- (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
- self.assertTrue(receivedQuery)
- self.assertTrue(receivedResponse)
- receivedQuery.id = query.id
- self.assertEquals(query, receivedQuery)
- self.assertEquals(response, receivedResponse)
-
- (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
- self.assertTrue(receivedQuery)
- self.assertTrue(receivedResponse)
- receivedQuery.id = query.id
- self.assertEquals(query, receivedQuery)
- self.assertEquals(response, receivedResponse)
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(response, receivedResponse)
def testRecordsCountRefuseFourAN(self):
"""
expectedResponse.set_rcode(dns.rcode.REFUSED)
expectedResponse.answer.append(rrset)
- (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
- self.assertEquals(receivedResponse, expectedResponse)
-
- (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
- self.assertEquals(receivedResponse, expectedResponse)
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (_, receivedResponse) = sender(query, response=None, useQueue=False)
+ self.assertEquals(receivedResponse, expectedResponse)
class TestRecordsCountNothingInNS(DNSDistTest):
_config_template = """
addAction(RecordsCountRule(DNSSection.Authority, 0, 0), AllowAction())
- addAction(AllRule(), RCodeAction(dnsdist.REFUSED))
+ addAction(AllRule(), RCodeAction(DNSRCode.REFUSED))
newServer{address="127.0.0.1:%s"}
"""
expectedResponse.set_rcode(dns.rcode.REFUSED)
expectedResponse.authority.append(rrset)
- (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
- self.assertEquals(receivedResponse, expectedResponse)
-
- (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
- self.assertEquals(receivedResponse, expectedResponse)
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (_, receivedResponse) = sender(query, response=None, useQueue=False)
+ self.assertEquals(receivedResponse, expectedResponse)
def testRecordsCountAllowEmptyNS(self):
dns.rdatatype.A,
'127.0.0.1'))
- (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
- self.assertTrue(receivedQuery)
- self.assertTrue(receivedResponse)
- receivedQuery.id = query.id
- self.assertEquals(query, receivedQuery)
- self.assertEquals(response, receivedResponse)
-
- (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
- self.assertTrue(receivedQuery)
- self.assertTrue(receivedResponse)
- receivedQuery.id = query.id
- self.assertEquals(query, receivedQuery)
- self.assertEquals(response, receivedResponse)
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(response, receivedResponse)
class TestRecordsCountNoOPTInAR(DNSDistTest):
_config_template = """
- addAction(NotRule(RecordsTypeCountRule(DNSSection.Additional, dnsdist.OPT, 0, 0)), RCodeAction(dnsdist.REFUSED))
+ addAction(NotRule(RecordsTypeCountRule(DNSSection.Additional, DNSQType.OPT, 0, 0)), RCodeAction(DNSRCode.REFUSED))
newServer{address="127.0.0.1:%s"}
"""
expectedResponse = dns.message.make_response(query)
expectedResponse.set_rcode(dns.rcode.REFUSED)
- (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
- self.assertEquals(receivedResponse, expectedResponse)
-
- (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
- self.assertEquals(receivedResponse, expectedResponse)
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (_, receivedResponse) = sender(query, response=None, useQueue=False)
+ self.assertEquals(receivedResponse, expectedResponse)
def testRecordsCountAllowNoOPTInAR(self):
"""
dns.rdatatype.A,
'127.0.0.1'))
- (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
- self.assertTrue(receivedQuery)
- self.assertTrue(receivedResponse)
- receivedQuery.id = query.id
- self.assertEquals(query, receivedQuery)
- self.assertEquals(response, receivedResponse)
-
- (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
- self.assertTrue(receivedQuery)
- self.assertTrue(receivedResponse)
- receivedQuery.id = query.id
- self.assertEquals(query, receivedQuery)
- self.assertEquals(response, receivedResponse)
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(response, receivedResponse)
def testRecordsCountAllowTwoARButNoOPT(self):
"""
dns.rdatatype.A,
'127.0.0.1'))
- (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
- self.assertTrue(receivedQuery)
- self.assertTrue(receivedResponse)
- receivedQuery.id = query.id
- self.assertEquals(query, receivedQuery)
- self.assertEquals(response, receivedResponse)
-
- (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
- self.assertTrue(receivedQuery)
- self.assertTrue(receivedResponse)
- receivedQuery.id = query.id
- self.assertEquals(query, receivedQuery)
- self.assertEquals(response, receivedResponse)
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(response, receivedResponse)