]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/test_TeeAction.py
Merge pull request #13980 from karelbilek/d_xfr
[thirdparty/pdns.git] / regression-tests.dnsdist / test_TeeAction.py
index 156dcda7be2965e47e10679538af5907c624fe83..0516fbc505c7d98a31b78bad59c5c9a14ad8f37f 100644 (file)
@@ -3,39 +3,48 @@ import base64
 import threading
 import clientsubnetoption
 import dns
-from dnsdisttests import DNSDistTest, Queue
+from dnsdisttests import DNSDistTest, Queue, pickAvailablePort
+from proxyprotocolutils import ProxyProtocolUDPResponder, ProxyProtocolTCPResponder
 
 class TestTeeAction(DNSDistTest):
 
     _consoleKey = DNSDistTest.generateConsoleKey()
     _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
-    _teeServerPort = 5390
+    _teeServerPort = pickAvailablePort()
+    _teeProxyServerPort = pickAvailablePort()
     _toTeeQueue = Queue()
     _fromTeeQueue = Queue()
+    _toTeeProxyQueue = Queue()
+    _fromTeeProxyQueue = Queue()
     _config_template = """
     setKey("%s")
     controlSocket("127.0.0.1:%s")
     newServer{address="127.0.0.1:%d"}
     addAction(QTypeRule(DNSQType.A), TeeAction("127.0.0.1:%d", true))
     addAction(QTypeRule(DNSQType.AAAA), TeeAction("127.0.0.1:%d", false))
+    addAction(QTypeRule(DNSQType.ANY), TeeAction("127.0.0.1:%d", false, '127.0.0.1', true))
     """
-    _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_teeServerPort', '_teeServerPort']
+    _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_teeServerPort', '_teeServerPort', '_teeProxyServerPort']
     @classmethod
     def startResponders(cls):
         print("Launching responders..")
 
         cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
-        cls._UDPResponder.setDaemon(True)
+        cls._UDPResponder.daemon = True
         cls._UDPResponder.start()
 
         cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, True])
-        cls._TCPResponder.setDaemon(True)
+        cls._TCPResponder.daemon = True
         cls._TCPResponder.start()
 
         cls._TeeResponder = threading.Thread(name='Tee Responder', target=cls.UDPResponder, args=[cls._teeServerPort, cls._toTeeQueue, cls._fromTeeQueue])
-        cls._TeeResponder.setDaemon(True)
+        cls._TeeResponder.daemon = True
         cls._TeeResponder.start()
 
+        cls._TeeProxyResponder = threading.Thread(name='Proxy Protocol Tee Responder', target=ProxyProtocolUDPResponder, args=[cls._teeProxyServerPort, cls._toTeeProxyQueue, cls._fromTeeProxyQueue])
+        cls._TeeProxyResponder.daemon = True
+        cls._TeeProxyResponder.start()
+
     def testTeeWithECS(self):
         """
         TeeAction: ECS
@@ -60,8 +69,8 @@ class TestTeeAction(DNSDistTest):
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
             receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(response, receivedResponse)
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
 
             # retrieve the query from the Tee server
             teedQuery = self._fromTeeQueue.get(True, 2.0)
@@ -72,7 +81,7 @@ class TestTeeAction(DNSDistTest):
 
         # check the TeeAction stats
         stats = self.sendConsoleCommand("getAction(0):printStats()")
-        self.assertEquals(stats, """noerrors\t%d
+        self.assertEqual(stats, """noerrors\t%d
 nxdomains\t0
 other-rcode\t0
 queries\t%d
@@ -108,8 +117,8 @@ tcp-drops\t0
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
             receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(response, receivedResponse)
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
 
             # retrieve the query from the Tee server
             teedQuery = self._fromTeeQueue.get(True, 2.0)
@@ -120,7 +129,53 @@ tcp-drops\t0
 
         # check the TeeAction stats
         stats = self.sendConsoleCommand("getAction(0):printStats()")
-        self.assertEquals(stats, """noerrors\t%d
+        self.assertEqual(stats, """noerrors\t%d
+nxdomains\t0
+other-rcode\t0
+queries\t%d
+recv-errors\t0
+refuseds\t0
+responses\t%d
+send-errors\t0
+servfails\t0
+tcp-drops\t0
+""" % (numberOfQueries, numberOfQueries, numberOfQueries))
+
+    def testTeeWithProxy(self):
+        """
+        TeeAction: Proxy
+        """
+        name = 'proxy.tee.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'ANY', 'IN')
+        response = dns.message.make_response(query)
+
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+
+        numberOfQueries = 10
+        for _ in range(numberOfQueries):
+            # push the response to the Tee Proxy server
+            self._toTeeProxyQueue.put(response, True, 2.0)
+
+            (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
+
+            # retrieve the query from the Tee Proxy server
+            [payload, teedQuery] = self._fromTeeProxyQueue.get(True, 2.0)
+            self.checkMessageNoEDNS(query, dns.message.from_wire(teedQuery))
+            self.checkMessageProxyProtocol(payload, '127.0.0.1', '127.0.0.1', False)
+
+        # check the TeeAction stats
+        stats = self.sendConsoleCommand("getAction(0):printStats()")
+        self.assertEqual(stats, """noerrors\t%d
 nxdomains\t0
 other-rcode\t0
 queries\t%d