]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/test_Basics.py
Merge pull request #8668 from cmouse/apex-dname
[thirdparty/pdns.git] / regression-tests.dnsdist / test_Basics.py
index e9c122e705af210f18fd59a70b955f44d67fe579..c2dc85f426bb51063c4ab6715a25306b6842f1d9 100644 (file)
@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-import unittest
 import dns
 import clientsubnetoption
 from dnsdisttests import DNSDistTest
@@ -9,15 +8,15 @@ class TestBasics(DNSDistTest):
     _config_template = """
     newServer{address="127.0.0.1:%s"}
     truncateTC(true)
-    addAction(AndRule{QTypeRule(dnsdist.ANY), TCPRule(false)}, TCAction())
-    addAction(RegexRule("evil[0-9]{4,}\\\\.regex\\\\.tests\\\\.powerdns\\\\.com$"), RCodeAction(dnsdist.REFUSED))
+    addAction(AndRule{QTypeRule(DNSQType.ANY), TCPRule(false)}, TCAction())
+    addAction(RegexRule("evil[0-9]{4,}\\\\.regex\\\\.tests\\\\.powerdns\\\\.com$"), RCodeAction(DNSRCode.REFUSED))
     mySMN = newSuffixMatchNode()
     mySMN:add(newDNSName("nameAndQtype.tests.powerdns.com."))
-    addAction(AndRule{SuffixMatchNodeRule(mySMN), QTypeRule("TXT")}, RCodeAction(dnsdist.NOTIMP))
+    addAction(AndRule{SuffixMatchNodeRule(mySMN), QTypeRule("TXT")}, RCodeAction(DNSRCode.NOTIMP))
     addAction(makeRule("drop.test.powerdns.com."), DropAction())
-    addAction(AndRule({QTypeRule(dnsdist.A),QNameRule("ds9a.nl")}), SpoofAction("1.2.3.4"))
-    addAction(newDNSName("dnsname.addaction.powerdns.com."), RCodeAction(dnsdist.REFUSED))
-    addAction({newDNSName("dnsname-table1.addaction.powerdns.com."), newDNSName("dnsname-table2.addaction.powerdns.com.")}, RCodeAction(dnsdist.REFUSED))
+    addAction(AndRule({QTypeRule(DNSQType.A),QNameRule("ds9a.nl")}), SpoofAction("1.2.3.4"))
+    addAction(newDNSName("dnsname.addaction.powerdns.com."), RCodeAction(DNSRCode.REFUSED))
+    addAction({newDNSName("dnsname-table1.addaction.powerdns.com."), newDNSName("dnsname-table2.addaction.powerdns.com.")}, RCodeAction(DNSRCode.REFUSED))
     """
 
     def testDropped(self):
@@ -30,11 +29,10 @@ class TestBasics(DNSDistTest):
         """
         name = 'drop.test.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, None)
 
     def testAWithECS(self):
         """
@@ -52,15 +50,12 @@ class TestBasics(DNSDistTest):
 
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testSimpleA(self):
         """
@@ -76,19 +71,14 @@ class TestBasics(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testAnyIsTruncated(self):
         """
@@ -100,6 +90,8 @@ class TestBasics(DNSDistTest):
         """
         name = 'any.tests.powerdns.com.'
         query = dns.message.make_query(name, 'ANY', 'IN')
+        # dnsdist sets RA = RD for TC responses
+        query.flags &= ~dns.flags.RD
         expectedResponse = dns.message.make_response(query)
         expectedResponse.flags |= dns.flags.TC
 
@@ -142,6 +134,48 @@ class TestBasics(DNSDistTest):
 
         response.answer.append(rrset)
         response.flags |= dns.flags.TC
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.flags |= dns.flags.TC
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(expectedResponse.flags, receivedResponse.flags)
+        self.assertEquals(expectedResponse.question, receivedResponse.question)
+        self.assertFalse(response.answer == receivedResponse.answer)
+        self.assertEquals(len(receivedResponse.answer), 0)
+        self.assertEquals(len(receivedResponse.authority), 0)
+        self.assertEquals(len(receivedResponse.additional), 0)
+        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+
+    def testTruncateTCEDNS(self):
+        """
+        Basics: Truncate TC with EDNS
+
+        dnsdist is configured to truncate TC (default),
+        we make the backend send responses
+        with TC set and additional content,
+        and check that the received response has been fixed.
+        Note that the query and initial response had EDNS,
+        so the final response should have it too.
+        """
+        name = 'atruncatetc.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=True)
+        response = dns.message.make_response(query)
+        # force a different responder payload than the one in the query,
+        # so we check that we don't just mirror it
+        response.payload = 4242
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+
+        response.answer.append(rrset)
+        response.flags |= dns.flags.TC
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.payload = 4242
+        expectedResponse.flags |= dns.flags.TC
 
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
         receivedQuery.id = query.id
@@ -152,6 +186,11 @@ class TestBasics(DNSDistTest):
         self.assertEquals(len(receivedResponse.answer), 0)
         self.assertEquals(len(receivedResponse.authority), 0)
         self.assertEquals(len(receivedResponse.additional), 0)
+        print(expectedResponse)
+        print(receivedResponse)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 4242)
 
     def testRegexReturnsRefused(self):
         """
@@ -164,14 +203,14 @@ class TestBasics(DNSDistTest):
         """
         name = 'evil4242.regex.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
+        query.flags &= ~dns.flags.RD
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testQNameReturnsSpoofed(self):
         """
@@ -191,12 +230,10 @@ class TestBasics(DNSDistTest):
                                     '1.2.3.4')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testDomainAndQTypeReturnsNotImplemented(self):
         """
@@ -209,14 +246,14 @@ class TestBasics(DNSDistTest):
         """
         name = 'nameAndQtype.tests.powerdns.com.'
         query = dns.message.make_query(name, 'TXT', 'IN')
+        query.flags &= ~dns.flags.RD
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.NOTIMP)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testDomainWithoutQTypeIsNotAffected(self):
         """
@@ -237,19 +274,14 @@ class TestBasics(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testOtherDomainANDQTypeIsNotAffected(self):
         """
@@ -270,19 +302,14 @@ class TestBasics(DNSDistTest):
                                     'nothing to see here')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testWrongResponse(self):
         """
@@ -306,17 +333,13 @@ class TestBasics(DNSDistTest):
                                     'nothing to see here')
         unrelatedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, unrelatedResponse)
-        self.assertTrue(receivedQuery)
-        self.assertEquals(receivedResponse, None)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, unrelatedResponse)
-        self.assertTrue(receivedQuery)
-        self.assertEquals(receivedResponse, None)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, unrelatedResponse)
+            self.assertTrue(receivedQuery)
+            self.assertEquals(receivedResponse, None)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
 
     def testHeaderOnlyRefused(self):
         """
@@ -328,17 +351,13 @@ class TestBasics(DNSDistTest):
         response.set_rcode(dns.rcode.REFUSED)
         response.question = []
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, response)
 
     def testHeaderOnlyNoErrorResponse(self):
         """
@@ -349,17 +368,13 @@ class TestBasics(DNSDistTest):
         response = dns.message.make_response(query)
         response.question = []
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, None)
 
     def testHeaderOnlyNXDResponse(self):
         """
@@ -371,17 +386,13 @@ class TestBasics(DNSDistTest):
         response.set_rcode(dns.rcode.NXDOMAIN)
         response.question = []
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, None)
 
     def testAddActionDNSName(self):
         """
@@ -389,11 +400,14 @@ class TestBasics(DNSDistTest):
         """
         name = 'dnsname.addaction.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
+        query.flags &= ~dns.flags.RD
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testAddActionDNSNames(self):
         """
@@ -401,12 +415,12 @@ class TestBasics(DNSDistTest):
         """
         for name in ['dnsname-table{}.addaction.powerdns.com.'.format(i) for i in range(1,2)]:
             query = dns.message.make_query(name, 'A', 'IN')
+            query.flags &= ~dns.flags.RD
             expectedResponse = dns.message.make_response(query)
             expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-            (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-            self.assertEquals(receivedResponse, expectedResponse)
+            for method in ("sendUDPQuery", "sendTCPQuery"):
+                sender = getattr(self, method)
+                (_, receivedResponse) = sender(query, response=None, useQueue=False)
+                self.assertEquals(receivedResponse, expectedResponse)
 
-if __name__ == '__main__':
-    unittest.main()
-    exit(0)