]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/test_Basics.py
Merge pull request #8713 from rgacogne/auth-strict-caches-size
[thirdparty/pdns.git] / regression-tests.dnsdist / test_Basics.py
index ea4b7b08b9a94d81d4e828cd60fd83fd3c0915cc..c2dc85f426bb51063c4ab6715a25306b6842f1d9 100644 (file)
@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-import unittest
 import dns
 import clientsubnetoption
 from dnsdisttests import DNSDistTest
@@ -9,15 +8,15 @@ class TestBasics(DNSDistTest):
     _config_template = """
     newServer{address="127.0.0.1:%s"}
     truncateTC(true)
-    addAction(AndRule{QTypeRule(dnsdist.ANY), TCPRule(false)}, TCAction())
-    addAction(RegexRule("evil[0-9]{4,}\\\\.regex\\\\.tests\\\\.powerdns\\\\.com$"), RCodeAction(dnsdist.REFUSED))
+    addAction(AndRule{QTypeRule(DNSQType.ANY), TCPRule(false)}, TCAction())
+    addAction(RegexRule("evil[0-9]{4,}\\\\.regex\\\\.tests\\\\.powerdns\\\\.com$"), RCodeAction(DNSRCode.REFUSED))
     mySMN = newSuffixMatchNode()
     mySMN:add(newDNSName("nameAndQtype.tests.powerdns.com."))
-    addAction(AndRule{SuffixMatchNodeRule(mySMN), QTypeRule("TXT")}, RCodeAction(dnsdist.NOTIMP))
+    addAction(AndRule{SuffixMatchNodeRule(mySMN), QTypeRule("TXT")}, RCodeAction(DNSRCode.NOTIMP))
     addAction(makeRule("drop.test.powerdns.com."), DropAction())
-    addAction(AndRule({QTypeRule(dnsdist.A),QNameRule("ds9a.nl")}), SpoofAction("1.2.3.4"))
-    addAction(newDNSName("dnsname.addaction.powerdns.com."), RCodeAction(dnsdist.REFUSED))
-    addAction({newDNSName("dnsname-table1.addaction.powerdns.com."), newDNSName("dnsname-table2.addaction.powerdns.com.")}, RCodeAction(dnsdist.REFUSED))
+    addAction(AndRule({QTypeRule(DNSQType.A),QNameRule("ds9a.nl")}), SpoofAction("1.2.3.4"))
+    addAction(newDNSName("dnsname.addaction.powerdns.com."), RCodeAction(DNSRCode.REFUSED))
+    addAction({newDNSName("dnsname-table1.addaction.powerdns.com."), newDNSName("dnsname-table2.addaction.powerdns.com.")}, RCodeAction(DNSRCode.REFUSED))
     """
 
     def testDropped(self):
@@ -30,11 +29,10 @@ class TestBasics(DNSDistTest):
         """
         name = 'drop.test.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, None)
 
     def testAWithECS(self):
         """
@@ -52,15 +50,12 @@ class TestBasics(DNSDistTest):
 
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testSimpleA(self):
         """
@@ -76,19 +71,14 @@ class TestBasics(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testAnyIsTruncated(self):
         """
@@ -100,6 +90,8 @@ class TestBasics(DNSDistTest):
         """
         name = 'any.tests.powerdns.com.'
         query = dns.message.make_query(name, 'ANY', 'IN')
+        # dnsdist sets RA = RD for TC responses
+        query.flags &= ~dns.flags.RD
         expectedResponse = dns.message.make_response(query)
         expectedResponse.flags |= dns.flags.TC
 
@@ -211,14 +203,14 @@ class TestBasics(DNSDistTest):
         """
         name = 'evil4242.regex.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
+        query.flags &= ~dns.flags.RD
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testQNameReturnsSpoofed(self):
         """
@@ -238,12 +230,10 @@ class TestBasics(DNSDistTest):
                                     '1.2.3.4')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testDomainAndQTypeReturnsNotImplemented(self):
         """
@@ -256,14 +246,14 @@ class TestBasics(DNSDistTest):
         """
         name = 'nameAndQtype.tests.powerdns.com.'
         query = dns.message.make_query(name, 'TXT', 'IN')
+        query.flags &= ~dns.flags.RD
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.NOTIMP)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testDomainWithoutQTypeIsNotAffected(self):
         """
@@ -284,19 +274,14 @@ class TestBasics(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testOtherDomainANDQTypeIsNotAffected(self):
         """
@@ -317,19 +302,14 @@ class TestBasics(DNSDistTest):
                                     'nothing to see here')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testWrongResponse(self):
         """
@@ -353,17 +333,13 @@ class TestBasics(DNSDistTest):
                                     'nothing to see here')
         unrelatedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, unrelatedResponse)
-        self.assertTrue(receivedQuery)
-        self.assertEquals(receivedResponse, None)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, unrelatedResponse)
-        self.assertTrue(receivedQuery)
-        self.assertEquals(receivedResponse, None)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, unrelatedResponse)
+            self.assertTrue(receivedQuery)
+            self.assertEquals(receivedResponse, None)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
 
     def testHeaderOnlyRefused(self):
         """
@@ -375,17 +351,13 @@ class TestBasics(DNSDistTest):
         response.set_rcode(dns.rcode.REFUSED)
         response.question = []
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, response)
 
     def testHeaderOnlyNoErrorResponse(self):
         """
@@ -396,17 +368,13 @@ class TestBasics(DNSDistTest):
         response = dns.message.make_response(query)
         response.question = []
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, None)
 
     def testHeaderOnlyNXDResponse(self):
         """
@@ -418,17 +386,13 @@ class TestBasics(DNSDistTest):
         response.set_rcode(dns.rcode.NXDOMAIN)
         response.question = []
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, None)
 
     def testAddActionDNSName(self):
         """
@@ -436,11 +400,14 @@ class TestBasics(DNSDistTest):
         """
         name = 'dnsname.addaction.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
+        query.flags &= ~dns.flags.RD
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testAddActionDNSNames(self):
         """
@@ -448,12 +415,12 @@ class TestBasics(DNSDistTest):
         """
         for name in ['dnsname-table{}.addaction.powerdns.com.'.format(i) for i in range(1,2)]:
             query = dns.message.make_query(name, 'A', 'IN')
+            query.flags &= ~dns.flags.RD
             expectedResponse = dns.message.make_response(query)
             expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-            (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-            self.assertEquals(receivedResponse, expectedResponse)
+            for method in ("sendUDPQuery", "sendTCPQuery"):
+                sender = getattr(self, method)
+                (_, receivedResponse) = sender(query, response=None, useQueue=False)
+                self.assertEquals(receivedResponse, expectedResponse)
 
-if __name__ == '__main__':
-    unittest.main()
-    exit(0)