]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/test_ProxyProtocol.py
Merge pull request #13702 from omoerbeek/rec-log-ref-wrapper
[thirdparty/pdns.git] / regression-tests.dnsdist / test_ProxyProtocol.py
index dd4ca4fbefb9dc010d537ae3e1f6a53c33a3e889..2ed60e08bc9f4351b7a59da384030f4b3f7691a6 100644 (file)
@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 
-import copy
 import dns
 import selectors
 import socket
@@ -12,6 +11,7 @@ import time
 
 from dnsdisttests import DNSDistTest, pickAvailablePort
 from proxyprotocol import ProxyProtocol
+from proxyprotocolutils import ProxyProtocolUDPResponder, ProxyProtocolTCPResponder
 from dnsdistdohtests import DNSDistDOHTest
 
 # Python2/3 compatibility hacks
@@ -20,118 +20,6 @@ try:
 except ImportError:
   from Queue import Queue
 
-def ProxyProtocolUDPResponder(port, fromQueue, toQueue):
-    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
-    try:
-        sock.bind(("127.0.0.1", port))
-    except socket.error as e:
-        print("Error binding in the Proxy Protocol UDP responder: %s" % str(e))
-        sys.exit(1)
-
-    while True:
-        data, addr = sock.recvfrom(4096)
-
-        proxy = ProxyProtocol()
-        if len(data) < proxy.HEADER_SIZE:
-            continue
-
-        if not proxy.parseHeader(data):
-            continue
-
-        if proxy.local:
-            # likely a healthcheck
-            data = data[proxy.HEADER_SIZE:]
-            request = dns.message.from_wire(data)
-            response = dns.message.make_response(request)
-            wire = response.to_wire()
-            sock.settimeout(2.0)
-            sock.sendto(wire, addr)
-            sock.settimeout(None)
-
-            continue
-
-        payload = data[:(proxy.HEADER_SIZE + proxy.contentLen)]
-        dnsData = data[(proxy.HEADER_SIZE + proxy.contentLen):]
-        toQueue.put([payload, dnsData], True, 2.0)
-        # computing the correct ID for the response
-        request = dns.message.from_wire(dnsData)
-        response = fromQueue.get(True, 2.0)
-        response.id = request.id
-
-        sock.settimeout(2.0)
-        sock.sendto(response.to_wire(), addr)
-        sock.settimeout(None)
-
-    sock.close()
-
-def ProxyProtocolTCPResponder(port, fromQueue, toQueue):
-    # be aware that this responder will not accept a new connection
-    # until the last one has been closed. This is done on purpose to
-    # to check for connection reuse, making sure that a lot of connections
-    # are not opened in parallel.
-    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
-    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
-    try:
-        sock.bind(("127.0.0.1", port))
-    except socket.error as e:
-        print("Error binding in the TCP responder: %s" % str(e))
-        sys.exit(1)
-
-    sock.listen(100)
-    while True:
-        (conn, _) = sock.accept()
-        conn.settimeout(5.0)
-        # try to read the entire Proxy Protocol header
-        proxy = ProxyProtocol()
-        header = conn.recv(proxy.HEADER_SIZE)
-        if not header:
-            conn.close()
-            continue
-
-        if not proxy.parseHeader(header):
-            conn.close()
-            continue
-
-        proxyContent = conn.recv(proxy.contentLen)
-        if not proxyContent:
-            conn.close()
-            continue
-
-        payload = header + proxyContent
-        while True:
-          try:
-            data = conn.recv(2)
-          except socket.timeout:
-            data = None
-
-          if not data:
-            conn.close()
-            break
-
-          (datalen,) = struct.unpack("!H", data)
-          data = conn.recv(datalen)
-
-          toQueue.put([payload, data], True, 2.0)
-
-          response = copy.deepcopy(fromQueue.get(True, 2.0))
-          if not response:
-            conn.close()
-            break
-
-          # computing the correct ID for the response
-          request = dns.message.from_wire(data)
-          response.id = request.id
-
-          wire = response.to_wire()
-          conn.send(struct.pack("!H", len(wire)))
-          conn.send(wire)
-
-        conn.close()
-
-    sock.close()
-
 toProxyQueue = Queue()
 fromProxyQueue = Queue()
 proxyResponderPort = pickAvailablePort()
@@ -492,7 +380,7 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
     addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=true})
     addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=false})
     setProxyProtocolACL( { "127.0.0.1/32" } )
-    newServer{address="127.0.0.1:%d", useProxyProtocol=true}
+    newServer{address="127.0.0.1:%d", useProxyProtocol=true, proxyProtocolAdvertiseTLS=true}
 
     function addValues(dq)
       dq:addProxyProtocolValue(0, 'foo')
@@ -789,7 +677,7 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
         receivedResponse.id = response.id
         self.assertEqual(receivedQuery, query)
         self.assertEqual(receivedResponse, response)
-        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
 
         for idx in range(5):
           receivedResponse = None
@@ -805,7 +693,7 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
           receivedResponse.id = response.id
           self.assertEqual(receivedQuery, query)
           self.assertEqual(receivedResponse, response)
-          self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+          self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
 
     def testProxyDoHSeveralQueriesOverConnectionPPInside(self):
         """
@@ -842,7 +730,7 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
         receivedResponse.id = response.id
         self.assertEqual(receivedQuery, query)
         self.assertEqual(receivedResponse, response)
-        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
 
         for idx in range(5):
           receivedResponse = None
@@ -858,7 +746,7 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
           receivedResponse.id = response.id
           self.assertEqual(receivedQuery, query)
           self.assertEqual(receivedResponse, response)
-          self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+          self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
 
     @classmethod
     def tearDownClass(cls):
@@ -936,6 +824,74 @@ class TestProxyProtocolNotExpected(DNSDistTest):
           print('timeout')
         self.assertEqual(receivedResponse, None)
 
+class TestProxyProtocolNotAllowedOnBind(DNSDistTest):
+    """
+    dnsdist is configured to expect a Proxy Protocol header on incoming queries but not on the 127.0.0.1 bind
+    """
+    _skipListeningOnCL = True
+    _config_template = """
+    -- proxy protocol payloads are not allowed on this bind address!
+    addLocal('127.0.0.1:%d', {enableProxyProtocol=false})
+    setProxyProtocolACL( { "127.0.0.1/8" } )
+    newServer{address="127.0.0.1:%d"}
+    """
+    # NORMAL responder, does not expect a proxy protocol payload!
+    _config_params = ['_dnsDistPort', '_testServerPort']
+
+    def testNoHeader(self):
+        """
+        Unexpected Proxy Protocol: no header
+        """
+        # no proxy protocol header and none is expected from this source, should be passed on
+        name = 'no-header.unexpected-proxy-protocol.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+          sender = getattr(self, method)
+          (receivedQuery, receivedResponse) = sender(query, response)
+          receivedQuery.id = query.id
+          self.assertEqual(query, receivedQuery)
+          self.assertEqual(response, receivedResponse)
+
+    def testIncomingProxyDest(self):
+        """
+        Unexpected Proxy Protocol: should be dropped
+        """
+        name = 'with-proxy-payload.unexpected-protocol-incoming.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+
+        # Make sure that the proxy payload does NOT turn into a legal qname
+        destAddr = "ff:db8::ffff"
+        destPort = 65535
+        srcAddr = "ff:db8::ffff"
+        srcPort = 65535
+
+        udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+        (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True)
+        self.assertEqual(receivedResponse, None)
+
+        tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+        wire = query.to_wire()
+
+        receivedResponse = None
+        try:
+          conn = self.openTCPConnection(2.0)
+          conn.send(tcpPayload)
+          conn.send(struct.pack("!H", len(wire)))
+          conn.send(wire)
+          receivedResponse = self.recvTCPResponseOverConnection(conn)
+        except socket.timeout:
+          print('timeout')
+        self.assertEqual(receivedResponse, None)
+
 class TestDOHWithOutgoingProxyProtocol(DNSDistDOHTest):
 
     _serverKey = 'server.key'