]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add regression tests for the proxy protocol
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 24 Feb 2020 14:40:22 +0000 (15:40 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 17 Mar 2020 13:12:55 +0000 (14:12 +0100)
regression-tests.dnsdist/proxyprotocol.py [new file with mode: 0644]
regression-tests.dnsdist/test_ProxyProtocol.py [new file with mode: 0644]

diff --git a/regression-tests.dnsdist/proxyprotocol.py b/regression-tests.dnsdist/proxyprotocol.py
new file mode 100644 (file)
index 0000000..cc34b9f
--- /dev/null
@@ -0,0 +1,122 @@
+#!/usr/bin/env python
+
+import socket
+import struct
+
+class ProxyProtocol(object):
+    MAGIC = b'\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A'
+    # Header is magic + versioncommand (1) + family (1) + content length (2)
+    HEADER_SIZE = len(MAGIC) + 1 + 1 + 2
+    PORT_SIZE = 2
+
+    def consumed(self):
+        return self.offset
+
+    def parseHeader(self, data):
+        if len(data) < self.HEADER_SIZE:
+            return False
+
+        if data[:len(self.MAGIC)] != self.MAGIC:
+            return False
+
+        self.version = int(data[12]) >> 4
+        if self.version != 0x02:
+            return False
+
+        self.command = int(data[12]) & ~0x20
+        self.local = False
+        self.offset = self.HEADER_SIZE
+
+        if self.command == 0x00:
+            self.local = True
+        elif self.command == 0x01:
+            self.family = int(data[13]) >> 4
+            if self.family == 0x01:
+                self.addrSize = 4
+            elif self.family == 0x02:
+                self.addrSize = 16
+            else:
+                return False
+
+            self.protocol = int(data[13]) & ~0xF0
+            if self.protocol == 0x01:
+                self.tcp = True
+            elif self.protocol == 0x02:
+                self.tcp = False
+            else:
+                return False
+        else:
+            return False
+
+        self.contentLen = struct.unpack("!H", data[14:16])[0]
+
+        if not self.local:
+            if self.contentLen < (self.addrSize * 2 + self.PORT_SIZE * 2):
+                return False
+
+        return True
+
+    def getAddr(self, data):
+        if len(data) < (self.consumed() + self.addrSize):
+            return False
+
+        value = None
+        if self.family == 0x01:
+            value = socket.inet_ntop(socket.AF_INET, data[self.offset:self.offset + self.addrSize])
+        else:
+            value = socket.inet_ntop(socket.AF_INET6, data[self.offset:self.offset + self.addrSize])
+
+        self.offset = self.offset + self.addrSize
+        return value
+
+    def getPort(self, data):
+        if len(data) < (self.consumed() + self.PORT_SIZE):
+            return False
+
+        value = struct.unpack('!H', data[self.offset:self.offset + self.PORT_SIZE])[0]
+        self.offset = self.offset + self.PORT_SIZE
+        return value
+
+    def parseAddressesAndPorts(self, data):
+        if self.local:
+            return True
+
+        if len(data) < (self.consumed() + self.addrSize * 2 + self.PORT_SIZE * 2):
+            return False
+
+        self.source = self.getAddr(data)
+        self.destination = self.getAddr(data)
+        self.sourcePort = self.getPort(data)
+        self.destinationPort = self.getPort(data)
+        return True
+
+    def parseAdditionalValues(self, data):
+        self.values = []
+        if self.local:
+            return True
+
+        if len(data) < (self.HEADER_SIZE + self.contentLen):
+            return False
+
+        remaining = self.HEADER_SIZE + self.contentLen - self.consumed()
+        if len(data) < remaining:
+            return False
+
+        while remaining >= 3:
+            valueType = data[self.offset]
+            self.offset = self.offset + 1
+            valueLen = struct.unpack("!H", data[self.offset:self.offset+2])[0]
+            self.offset = self.offset + 2
+
+            remaining = remaining - 3
+            if valueLen > 0:
+                if valueLen > remaining:
+                    return False
+                self.values.append([valueType, data[self.offset:self.offset+valueLen]])
+                self.offset = self.offset + valueLen
+                remaining = remaining - valueLen
+
+            else:
+                self.values.append([valueType, ""])
+
+        return True
diff --git a/regression-tests.dnsdist/test_ProxyProtocol.py b/regression-tests.dnsdist/test_ProxyProtocol.py
new file mode 100644 (file)
index 0000000..955e5d9
--- /dev/null
@@ -0,0 +1,302 @@
+#!/usr/bin/env python
+
+import dns
+import socket
+import struct
+import sys
+import threading
+
+from dnsdisttests import DNSDistTest
+from proxyprotocol import ProxyProtocol
+
+# Python2/3 compatibility hacks
+try:
+  from queue import Queue
+except ImportError:
+  from Queue import Queue
+
+def ProxyProtocolUDPResponder(port, fromQueue, toQueue):
+    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+    try:
+        sock.bind(("127.0.0.1", port))
+    except socket.error as e:
+        print("Error binding in the Proxy Protocol UDP responder: %s" % str(e))
+        sys.exit(1)
+
+    while True:
+        data, addr = sock.recvfrom(4096)
+
+        proxy = ProxyProtocol()
+        if len(data) < proxy.HEADER_SIZE:
+            continue
+
+        if not proxy.parseHeader(data):
+            continue
+
+        if proxy.local:
+            # likely a healthcheck
+            data = data[proxy.HEADER_SIZE:]
+            request = dns.message.from_wire(data)
+            response = dns.message.make_response(request)
+            wire = response.to_wire()
+            sock.settimeout(2.0)
+            sock.sendto(wire, addr)
+            sock.settimeout(None)
+
+            continue
+
+        payload = data[:(proxy.HEADER_SIZE + proxy.contentLen)]
+        dnsData = data[(proxy.HEADER_SIZE + proxy.contentLen):]
+        toQueue.put([payload, dnsData], True, 2.0)
+        # computing the correct ID for the response
+        request = dns.message.from_wire(dnsData)
+        response = fromQueue.get(True, 2.0)
+        response.id = request.id
+
+        sock.settimeout(2.0)
+        sock.sendto(response.to_wire(), addr)
+        sock.settimeout(None)
+
+    sock.close()
+
+def ProxyProtocolTCPResponder(port, fromQueue, toQueue):
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+    try:
+        sock.bind(("127.0.0.1", port))
+    except socket.error as e:
+        print("Error binding in the TCP responder: %s" % str(e))
+        sys.exit(1)
+
+    sock.listen(100)
+    while True:
+        (conn, _) = sock.accept()
+        conn.settimeout(5.0)
+        # try to read the entire Proxy Protocol header
+        proxy = ProxyProtocol()
+        header = conn.recv(proxy.HEADER_SIZE)
+        if not header:
+            conn.close()
+            continue
+
+        if not proxy.parseHeader(header):
+            conn.close()
+            continue
+
+        proxyContent = conn.recv(proxy.contentLen)
+        if not proxyContent:
+            conn.close()
+            continue
+
+        payload = header + proxyContent
+
+        data = conn.recv(2)
+        (datalen,) = struct.unpack("!H", data)
+
+        data = conn.recv(datalen)
+
+        toQueue.put([payload, data], True, 2.0)
+
+        response = fromQueue.get(True, 2.0)
+        if not response:
+            conn.close()
+            continue
+
+        # computing the correct ID for the response
+        request = dns.message.from_wire(data)
+        response.id = request.id
+
+        wire = response.to_wire()
+        conn.send(struct.pack("!H", len(wire)))
+        conn.send(wire)
+        conn.close()
+
+    sock.close()
+
+toProxyQueue = Queue()
+fromProxyQueue = Queue()
+proxyResponderPort = 5470
+
+udpResponder = threading.Thread(name='UDP Proxy Protocol Responder', target=ProxyProtocolUDPResponder, args=[proxyResponderPort, toProxyQueue, fromProxyQueue])
+udpResponder.setDaemon(True)
+udpResponder.start()
+tcpResponder = threading.Thread(name='TCP Proxy Protocol Responder', target=ProxyProtocolTCPResponder, args=[proxyResponderPort, toProxyQueue, fromProxyQueue])
+tcpResponder.setDaemon(True)
+tcpResponder.start()
+
+class ProxyProtocolTest(DNSDistTest):
+    _proxyResponderPort = proxyResponderPort
+    _config_params = ['_proxyResponderPort']
+
+    def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=[]):
+      proxy = ProxyProtocol()
+      self.assertTrue(proxy.parseHeader(receivedProxyPayload))
+      self.assertEquals(proxy.version, 0x02)
+      self.assertEquals(proxy.command, 0x01)
+      self.assertEquals(proxy.family, 0x01)
+      if not isTCP:
+        self.assertEquals(proxy.protocol, 0x02)
+      else:
+        self.assertEquals(proxy.protocol, 0x01)
+      self.assertGreater(proxy.contentLen, 0)
+
+      self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload))
+      self.assertEquals(proxy.source, source)
+      self.assertEquals(proxy.destination, destination)
+      #self.assertEquals(proxy.sourcePort, sourcePort)
+      self.assertEquals(proxy.destinationPort, self._dnsDistPort)
+
+      self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload))
+      proxy.values.sort()
+      values.sort()
+      self.assertEquals(proxy.values, values)
+
+class TestProxyProtocol(ProxyProtocolTest):
+    """
+    dnsdist is configured to prepend a Proxy Protocol header to the query
+    """
+
+    _config_template = """
+    newServer{address="127.0.0.1:%d", useProxyProtocol=true}
+
+    function addValues(dq)
+      local values = { ["0"]="foo", ["42"]="bar" }
+      dq:setProxyProtocolValues(values)
+      return DNSAction.None
+    end
+
+    addAction("values-lua.proxy.tests.powerdns.com.", LuaAction(addValues))
+    """
+    _config_params = ['_proxyResponderPort']
+
+    def testProxyUDP(self):
+        """
+        Proxy Protocol: no value (UDP)
+        """
+        name = 'simple-udp.proxy.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        toProxyQueue.put(response, True, 2.0)
+
+        data = query.to_wire()
+        self._sock.send(data)
+        receivedResponse = None
+        try:
+            self._sock.settimeout(2.0)
+            data = self._sock.recv(4096)
+        except socket.timeout:
+            print('timeout')
+            data = None
+        if data:
+            receivedResponse = dns.message.from_wire(data)
+
+        (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+        self.assertTrue(receivedProxyPayload)
+        self.assertTrue(receivedDNSData)
+        self.assertTrue(receivedResponse)
+
+        receivedQuery = dns.message.from_wire(receivedDNSData)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', False)
+
+    def testProxyTCP(self):
+      """
+        Proxy Protocol: no value (TCP)
+      """
+      name = 'simple-tcp.proxy.tests.powerdns.com.'
+      query = dns.message.make_query(name, 'A', 'IN')
+      response = dns.message.make_response(query)
+
+      toProxyQueue.put(response, True, 2.0)
+
+      conn = self.openTCPConnection(2.0)
+      data = query.to_wire()
+      self.sendTCPQueryOverConnection(conn, data, rawQuery=True)
+      receivedResponse = None
+      try:
+        receivedResponse = self.recvTCPResponseOverConnection(conn)
+      except socket.timeout:
+            print('timeout')
+
+      (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+      self.assertTrue(receivedProxyPayload)
+      self.assertTrue(receivedDNSData)
+      self.assertTrue(receivedResponse)
+
+      receivedQuery = dns.message.from_wire(receivedDNSData)
+      receivedQuery.id = query.id
+      receivedResponse.id = response.id
+      self.assertEquals(receivedQuery, query)
+      self.assertEquals(receivedResponse, response)
+      self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True)
+
+    def testProxyUDPWithValuesFromLua(self):
+        """
+        Proxy Protocol: values from Lua (UDP)
+        """
+        name = 'values-lua.proxy.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        toProxyQueue.put(response, True, 2.0)
+
+        data = query.to_wire()
+        self._sock.send(data)
+        receivedResponse = None
+        try:
+            self._sock.settimeout(2.0)
+            data = self._sock.recv(4096)
+        except socket.timeout:
+            print('timeout')
+            data = None
+        if data:
+            receivedResponse = dns.message.from_wire(data)
+
+        (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+        self.assertTrue(receivedProxyPayload)
+        self.assertTrue(receivedDNSData)
+        self.assertTrue(receivedResponse)
+
+        receivedQuery = dns.message.from_wire(receivedDNSData)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', False, [ [0, b'foo'] , [ 42, b'bar'] ])
+
+    def testProxyTCPWithValuesFromLua(self):
+      """
+        Proxy Protocol: values from Lua (TCP)
+      """
+      name = 'values-lua.proxy.tests.powerdns.com.'
+      query = dns.message.make_query(name, 'A', 'IN')
+      response = dns.message.make_response(query)
+
+      toProxyQueue.put(response, True, 2.0)
+
+      conn = self.openTCPConnection(2.0)
+      data = query.to_wire()
+      self.sendTCPQueryOverConnection(conn, data, rawQuery=True)
+      receivedResponse = None
+      try:
+        receivedResponse = self.recvTCPResponseOverConnection(conn)
+      except socket.timeout:
+            print('timeout')
+
+      (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+      self.assertTrue(receivedProxyPayload)
+      self.assertTrue(receivedDNSData)
+      self.assertTrue(receivedResponse)
+
+      receivedQuery = dns.message.from_wire(receivedDNSData)
+      receivedQuery.id = query.id
+      receivedResponse.id = response.id
+      self.assertEquals(receivedQuery, query)
+      self.assertEquals(receivedResponse, response)
+      self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'] , [ 42, b'bar'] ])