]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/test_RecordsCount.py
Merge pull request #8945 from rgacogne/ddist-x-forwarded-for
[thirdparty/pdns.git] / regression-tests.dnsdist / test_RecordsCount.py
index cfd3d614bc067e2ffd2f883d62c5420a60f8f089..07741dba39d5a9237ae6174991d2fdd80cf4cf1c 100644 (file)
@@ -7,7 +7,7 @@ from dnsdisttests import DNSDistTest
 class TestRecordsCountOnlyOneAR(DNSDistTest):
 
     _config_template = """
-    addAction(NotRule(RecordsCountRule(DNSSection.Additional, 1, 1)), RCodeAction(dnsdist.REFUSED))
+    addAction(NotRule(RecordsCountRule(DNSSection.Additional, 1, 1)), RCodeAction(DNSRCode.REFUSED))
     newServer{address="127.0.0.1:%s"}
     """
 
@@ -20,14 +20,14 @@ class TestRecordsCountOnlyOneAR(DNSDistTest):
         """
         name = 'refuseemptyar.recordscount.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
+        query.flags &= ~dns.flags.RD
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testRecordsCountAllowOneAR(self):
         """
@@ -45,19 +45,14 @@ class TestRecordsCountOnlyOneAR(DNSDistTest):
                                                    dns.rdatatype.A,
                                                    '127.0.0.1'))
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testRecordsCountRefuseTwoAR(self):
         """
@@ -68,6 +63,7 @@ class TestRecordsCountOnlyOneAR(DNSDistTest):
         """
         name = 'refusetwoar.recordscount.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True)
+        query.flags &= ~dns.flags.RD
         query.additional.append(dns.rrset.from_text(name,
                                                     3600,
                                                     dns.rdataclass.IN,
@@ -76,17 +72,16 @@ class TestRecordsCountOnlyOneAR(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 class TestRecordsCountMoreThanOneLessThanFour(DNSDistTest):
 
     _config_template = """
     addAction(RecordsCountRule(DNSSection.Answer, 2, 3), AllowAction())
-    addAction(AllRule(), RCodeAction(dnsdist.REFUSED))
+    addAction(AllRule(), RCodeAction(DNSRCode.REFUSED))
     newServer{address="127.0.0.1:%s"}
     """
 
@@ -99,14 +94,14 @@ class TestRecordsCountMoreThanOneLessThanFour(DNSDistTest):
         """
         name = 'refusenoan.recordscount.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
+        query.flags &= ~dns.flags.RD
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testRecordsCountAllowTwoAN(self):
         """
@@ -126,19 +121,14 @@ class TestRecordsCountMoreThanOneLessThanFour(DNSDistTest):
         response = dns.message.make_response(query)
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testRecordsCountRefuseFourAN(self):
         """
@@ -149,6 +139,7 @@ class TestRecordsCountMoreThanOneLessThanFour(DNSDistTest):
         """
         name = 'refusefouran.recordscount.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True)
+        query.flags &= ~dns.flags.RD
         rrset = dns.rrset.from_text_list(name,
                                          3600,
                                          dns.rdataclass.IN,
@@ -160,17 +151,16 @@ class TestRecordsCountMoreThanOneLessThanFour(DNSDistTest):
         expectedResponse.set_rcode(dns.rcode.REFUSED)
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 class TestRecordsCountNothingInNS(DNSDistTest):
 
     _config_template = """
     addAction(RecordsCountRule(DNSSection.Authority, 0, 0), AllowAction())
-    addAction(AllRule(), RCodeAction(dnsdist.REFUSED))
+    addAction(AllRule(), RCodeAction(DNSRCode.REFUSED))
     newServer{address="127.0.0.1:%s"}
     """
 
@@ -189,15 +179,15 @@ class TestRecordsCountNothingInNS(DNSDistTest):
                                     dns.rdatatype.NS,
                                     'ns.tests.powerdns.com.')
         query.authority.append(rrset)
+        query.flags &= ~dns.flags.RD
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
         expectedResponse.authority.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 
     def testRecordsCountAllowEmptyNS(self):
@@ -216,24 +206,19 @@ class TestRecordsCountNothingInNS(DNSDistTest):
                                                    dns.rdatatype.A,
                                                    '127.0.0.1'))
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
 class TestRecordsCountNoOPTInAR(DNSDistTest):
 
     _config_template = """
-    addAction(NotRule(RecordsTypeCountRule(DNSSection.Additional, dnsdist.OPT, 0, 0)), RCodeAction(dnsdist.REFUSED))
+    addAction(NotRule(RecordsTypeCountRule(DNSSection.Additional, DNSQType.OPT, 0, 0)), RCodeAction(DNSRCode.REFUSED))
     newServer{address="127.0.0.1:%s"}
     """
 
@@ -246,14 +231,14 @@ class TestRecordsCountNoOPTInAR(DNSDistTest):
         """
         name = 'refuseoptinar.recordscount.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True)
+        query.flags &= ~dns.flags.RD
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testRecordsCountAllowNoOPTInAR(self):
         """
@@ -271,19 +256,14 @@ class TestRecordsCountNoOPTInAR(DNSDistTest):
                                                    dns.rdatatype.A,
                                                    '127.0.0.1'))
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testRecordsCountAllowTwoARButNoOPT(self):
         """
@@ -312,16 +292,11 @@ class TestRecordsCountNoOPTInAR(DNSDistTest):
                                                    dns.rdatatype.A,
                                                    '127.0.0.1'))
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)