]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.recursor-dnssec/test_ProxyProtocol.py
Merge pull request #12808 from omoerbeek/args-delint
[thirdparty/pdns.git] / regression-tests.recursor-dnssec / test_ProxyProtocol.py
index ee521ea76a0b74f4de0eead8a9f362fa227bbac8..f28020f58dfd208bb7ed4cfdbf53907880204bd1 100644 (file)
@@ -5,6 +5,11 @@ import struct
 import sys
 import time
 
+try:
+    range = xrange
+except NameError:
+    pass
+
 from recursortests import RecursorTest
 from proxyprotocol import ProxyProtocol
 
@@ -27,62 +32,6 @@ class ProxyProtocolRecursorTest(RecursorTest):
     def tearDownClass(cls):
         cls.tearDownRecursor()
 
-    @classmethod
-    def sendUDPQueryWithProxyProtocol(cls, query, v6, source, destination, sourcePort, destinationPort, values=[], timeout=2.0):
-        queryPayload = query.to_wire()
-        ppPayload = ProxyProtocol.getPayload(False, False, v6, source, destination, sourcePort, destinationPort, values)
-        payload = ppPayload + queryPayload
-
-        if timeout:
-            cls._sock.settimeout(timeout)
-
-        try:
-            cls._sock.send(payload)
-            data = cls._sock.recv(4096)
-        except socket.timeout:
-            data = None
-        finally:
-            if timeout:
-                cls._sock.settimeout(None)
-
-        message = None
-        if data:
-            message = dns.message.from_wire(data)
-        return message
-
-    @classmethod
-    def sendTCPQueryWithProxyProtocol(cls, query, v6, source, destination, sourcePort, destinationPort, values=[], timeout=2.0):
-        queryPayload = query.to_wire()
-        ppPayload = ProxyProtocol.getPayload(False, False, v6, source, destination, sourcePort, destinationPort, values)
-
-        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        if timeout:
-            sock.settimeout(timeout)
-
-        sock.connect(("127.0.0.1", cls._recursorPort))
-
-        try:
-            sock.send(ppPayload)
-            sock.send(struct.pack("!H", len(queryPayload)))
-            sock.send(queryPayload)
-            data = sock.recv(2)
-            if data:
-                (datalen,) = struct.unpack("!H", data)
-                data = sock.recv(datalen)
-        except socket.timeout as e:
-            print("Timeout: %s" % (str(e)))
-            data = None
-        except socket.error as e:
-            print("Network error: %s" % (str(e)))
-            data = None
-        finally:
-            sock.close()
-
-        message = None
-        if data:
-            message = dns.message.from_wire(data)
-        return message
-
 class ProxyProtocolAllowedRecursorTest(ProxyProtocolRecursorTest):
     _confdir = 'ProxyProtocol'
     _lua_dns_script_file = """
@@ -449,6 +398,51 @@ class ProxyProtocolAllowedRecursorTest(ProxyProtocolRecursorTest):
             res = sender(query, True, '2001:db8::1', '2001:db8::ff', 0, 65535, [ [0, b'foo' ], [ 255, b'bar'] ])
             self.assertEqual(res, None)
 
+    def testIPv6ProxyProtocolSeveralQueriesOverTCP(self):
+        qname = 'several-queries-tcp.proxy-protocol.recursor-tests.powerdns.com.'
+        expected = dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '192.0.2.1')
+
+        query = dns.message.make_query(qname, 'A', want_dnssec=True)
+        queryPayload = query.to_wire()
+        ppPayload = ProxyProtocol.getPayload(False, True, True, '::42', '2001:db8::ff', 0, 65535, [ [0, b'foo' ], [ 255, b'bar'] ])
+        payload = ppPayload + queryPayload
+
+        # TCP
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.settimeout(2.0)
+        sock.connect(("127.0.0.1", self._recursorPort))
+
+        sock.send(ppPayload)
+
+        count = 0
+        for idx in range(5):
+            try:
+                sock.send(struct.pack("!H", len(queryPayload)))
+                sock.send(queryPayload)
+
+                data = sock.recv(2)
+                if data:
+                    (datalen,) = struct.unpack("!H", data)
+                    data = sock.recv(datalen)
+            except socket.timeout as e:
+                print("Timeout: %s" % (str(e)))
+                data = None
+                break
+            except socket.error as e:
+                print("Network error: %s" % (str(e)))
+                data = None
+                break
+
+            res = None
+            if data:
+                res = dns.message.from_wire(data)
+            self.assertRcodeEqual(res, dns.rcode.NOERROR)
+            self.assertRRsetInAnswer(res, expected)
+            count = count + 1
+
+        self.assertEqual(count, 5)
+        sock.close()
+
 class ProxyProtocolAllowedFFIRecursorTest(ProxyProtocolAllowedRecursorTest):
     # same tests than ProxyProtocolAllowedRecursorTest but with the Lua FFI interface instead of the regular one
     _confdir = 'ProxyProtocolFFI'