]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add regression tests for incoming Proxy Protocol
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 15 Oct 2020 15:34:17 +0000 (17:34 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 11 Jan 2021 09:22:00 +0000 (10:22 +0100)
regression-tests.dnsdist/test_ProxyProtocol.py

index 9dd9a3fac0e479723bb4f86411e4dc13470036a4..761caee597596a3a5743f2d69e9abfcb651472a7 100644 (file)
@@ -142,12 +142,15 @@ class ProxyProtocolTest(DNSDistTest):
     _proxyResponderPort = proxyResponderPort
     _config_params = ['_proxyResponderPort']
 
-    def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=[]):
+    def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=[], v6=False, sourcePort=None, destinationPort=None):
       proxy = ProxyProtocol()
       self.assertTrue(proxy.parseHeader(receivedProxyPayload))
       self.assertEquals(proxy.version, 0x02)
       self.assertEquals(proxy.command, 0x01)
-      self.assertEquals(proxy.family, 0x01)
+      if v6:
+        self.assertEquals(proxy.family, 0x02)
+      else:
+        self.assertEquals(proxy.family, 0x01)
       if not isTCP:
         self.assertEquals(proxy.protocol, 0x02)
       else:
@@ -157,8 +160,12 @@ class ProxyProtocolTest(DNSDistTest):
       self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload))
       self.assertEquals(proxy.source, source)
       self.assertEquals(proxy.destination, destination)
-      #self.assertEquals(proxy.sourcePort, sourcePort)
-      self.assertEquals(proxy.destinationPort, self._dnsDistPort)
+      if sourcePort:
+        self.assertEquals(proxy.sourcePort, sourcePort)
+      if destinationPort:
+        self.assertEquals(proxy.destinationPort, destinationPort)
+      else:
+        self.assertEquals(proxy.destinationPort, self._dnsDistPort)
 
       self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload))
       proxy.values.sort()
@@ -410,3 +417,337 @@ class TestProxyProtocol(ProxyProtocolTest):
         self.assertEquals(receivedQuery, query)
         self.assertEquals(receivedResponse, response)
         self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [])
+
+class TestProxyProtocolIncoming(ProxyProtocolTest):
+    """
+    dnsdist is configured to prepend a Proxy Protocol header to the query and expect one on incoming queries
+    """
+
+    _config_template = """
+    setProxyProtocolACL( { "127.0.0.1/32" } )
+    newServer{address="127.0.0.1:%d", useProxyProtocol=true}
+
+    function addValues(dq)
+      dq:addProxyProtocolValue(0, 'foo')
+      dq:addProxyProtocolValue(42, 'bar')
+      return DNSAction.None
+    end
+
+    -- refuse queries with no TLV value type 2
+    addAction(NotRule(ProxyProtocolValueRule(2)), RCodeAction(DNSRCode.REFUSED))
+    -- or with a TLV value type 3 different from "proxy"
+    addAction(NotRule(ProxyProtocolValueRule(3, "proxy")), RCodeAction(DNSRCode.REFUSED))
+
+    function answerBasedOnForwardedDest(dq)
+      local port = dq.localaddr:getPort()
+      local dest = dq.localaddr:toString()
+      return DNSAction.Spoof, "address-was-"..dest.."-port-was-"..port..".proxy-protocol-incoming.tests.powerdns.com."
+    end
+    addAction("get-forwarded-dest.proxy-protocol-incoming.tests.powerdns.com.", LuaAction(answerBasedOnForwardedDest))
+
+    function answerBasedOnForwardedSrc(dq)
+      local port = dq.remoteaddr:getPort()
+      local src = dq.remoteaddr:toString()
+      return DNSAction.Spoof, "address-was-"..src.."-port-was-"..port..".proxy-protocol-incoming.tests.powerdns.com."
+    end
+    addAction("get-forwarded-src.proxy-protocol-incoming.tests.powerdns.com.", LuaAction(answerBasedOnForwardedSrc))
+
+    -- add these values for all queries
+    addAction("proxy-protocol-incoming.tests.powerdns.com.", LuaAction(addValues))
+    addAction("proxy-protocol-incoming.tests.powerdns.com.", AddProxyProtocolValueAction(1, "dnsdist"))
+    addAction("proxy-protocol-incoming.tests.powerdns.com.", AddProxyProtocolValueAction(255, "proxy-protocol"))
+
+    -- override all existing values
+    addAction("override.proxy-protocol-incoming.tests.powerdns.com.", SetProxyProtocolValuesAction({["50"]="overridden"}))
+    """
+    _config_params = ['_proxyResponderPort']
+    _verboseMode = True
+
+    def testNoHeader(self):
+        """
+        Incoming Proxy Protocol: no header
+        """
+        # no proxy protocol header while one is expected, should be dropped
+        name = 'no-header.incoming-proxy-protocol.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+          sender = getattr(self, method)
+          (_, receivedResponse) = sender(query, response=None)
+          self.assertEquals(receivedResponse, None)
+
+    def testIncomingProxyDest(self):
+        """
+        Incoming Proxy Protocol: values from Lua
+        """
+        name = 'get-forwarded-dest.proxy-protocol-incoming.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+
+        destAddr = "2001:db8::9"
+        destPort = 9999
+        srcAddr = "2001:db8::8"
+        srcPort = 8888
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.CNAME,
+                                    "address-was-{}-port-was-{}.proxy-protocol-incoming.tests.powerdns.com.".format(destAddr, destPort, self._dnsDistPort))
+        response.answer.append(rrset)
+
+        udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+        (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True)
+        self.assertEquals(receivedResponse, response)
+
+        tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+        wire = query.to_wire()
+
+        receivedResponse = None
+        try:
+          conn = self.openTCPConnection(2.0)
+          conn.send(tcpPayload)
+          conn.send(struct.pack("!H", len(wire)))
+          conn.send(wire)
+          receivedResponse = self.recvTCPResponseOverConnection(conn)
+        except socket.timeout:
+          print('timeout')
+        self.assertEquals(receivedResponse, response)
+
+    def testProxyUDPWithValuesFromLua(self):
+        """
+        Incoming Proxy Protocol: values from Lua (UDP)
+        """
+        name = 'values-lua.proxy-protocol-incoming.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        destAddr = "2001:db8::9"
+        destPort = 9999
+        srcAddr = "2001:db8::8"
+        srcPort = 8888
+        response = dns.message.make_response(query)
+
+        udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+        toProxyQueue.put(response, True, 2.0)
+        (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True)
+
+        (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+        self.assertTrue(receivedProxyPayload)
+        self.assertTrue(receivedDNSData)
+        self.assertTrue(receivedResponse)
+
+        receivedQuery = dns.message.from_wire(receivedDNSData)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+        self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, False, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)
+
+    def testProxyTCPWithValuesFromLua(self):
+        """
+        Incoming Proxy Protocol: values from Lua (TCP)
+        """
+        name = 'values-lua.proxy-protocol-incoming.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        destAddr = "2001:db8::9"
+        destPort = 9999
+        srcAddr = "2001:db8::8"
+        srcPort = 8888
+        response = dns.message.make_response(query)
+
+        tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+
+        toProxyQueue.put(response, True, 2.0)
+
+        wire = query.to_wire()
+
+        receivedResponse = None
+        try:
+          conn = self.openTCPConnection(2.0)
+          conn.send(tcpPayload)
+          conn.send(struct.pack("!H", len(wire)))
+          conn.send(wire)
+          receivedResponse = self.recvTCPResponseOverConnection(conn)
+        except socket.timeout:
+          print('timeout')
+        self.assertEquals(receivedResponse, response)
+
+        (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+        self.assertTrue(receivedProxyPayload)
+        self.assertTrue(receivedDNSData)
+        self.assertTrue(receivedResponse)
+
+        receivedQuery = dns.message.from_wire(receivedDNSData)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+        self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)
+
+    def testProxyUDPWithValueOverride(self):
+        """
+        Incoming Proxy Protocol: override existing value (UDP)
+        """
+        name = 'override.proxy-protocol-incoming.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        destAddr = "2001:db8::9"
+        destPort = 9999
+        srcAddr = "2001:db8::8"
+        srcPort = 8888
+        response = dns.message.make_response(query)
+
+        udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [2, b'foo'], [3, b'proxy'], [ 50, b'initial-value']])
+        toProxyQueue.put(response, True, 2.0)
+        (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True)
+
+        (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+        self.assertTrue(receivedProxyPayload)
+        self.assertTrue(receivedDNSData)
+        self.assertTrue(receivedResponse)
+
+        receivedQuery = dns.message.from_wire(receivedDNSData)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+        self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, False, [ [50, b'overridden'] ], True, srcPort, destPort)
+
+    def testProxyTCPSeveralQueriesOverConnection(self):
+        """
+        Incoming Proxy Protocol: Several queries over the same connection (TCP)
+        """
+        name = 'several-queries.proxy-protocol-incoming.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        destAddr = "2001:db8::9"
+        destPort = 9999
+        srcAddr = "2001:db8::8"
+        srcPort = 8888
+        response = dns.message.make_response(query)
+
+        tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+
+        toProxyQueue.put(response, True, 2.0)
+
+        wire = query.to_wire()
+
+        receivedResponse = None
+        conn = self.openTCPConnection(2.0)
+        try:
+          conn.send(tcpPayload)
+          conn.send(struct.pack("!H", len(wire)))
+          conn.send(wire)
+          receivedResponse = self.recvTCPResponseOverConnection(conn)
+        except socket.timeout:
+          print('timeout')
+        self.assertEquals(receivedResponse, response)
+
+        (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+        self.assertTrue(receivedProxyPayload)
+        self.assertTrue(receivedDNSData)
+        self.assertTrue(receivedResponse)
+
+        receivedQuery = dns.message.from_wire(receivedDNSData)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+        self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)
+
+        for idx in range(5):
+          receivedResponse = None
+          toProxyQueue.put(response, True, 2.0)
+          try:
+            conn.send(struct.pack("!H", len(wire)))
+            conn.send(wire)
+            receivedResponse = self.recvTCPResponseOverConnection(conn)
+          except socket.timeout:
+            print('timeout')
+
+          self.assertEquals(receivedResponse, response)
+
+          (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+          self.assertTrue(receivedProxyPayload)
+          self.assertTrue(receivedDNSData)
+          self.assertTrue(receivedResponse)
+
+          receivedQuery = dns.message.from_wire(receivedDNSData)
+          receivedQuery.id = query.id
+          receivedResponse.id = response.id
+          self.assertEquals(receivedQuery, query)
+          self.assertEquals(receivedResponse, response)
+          self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)
+
+class TestProxyProtocolNotExpected(DNSDistTest):
+    """
+    dnsdist is configured to expect a Proxy Protocol header on incoming queries but not from 127.0.0.1
+    """
+
+    _config_template = """
+    setProxyProtocolACL( { "192.0.2.1/32" } )
+    newServer{address="127.0.0.1:%d"}
+    """
+    # NORMAL responder, does not expect a proxy protocol payload!
+    _config_params = ['_testServerPort']
+    _verboseMode = True
+
+    def testNoHeader(self):
+        """
+        Unexpected Proxy Protocol: no header
+        """
+        # no proxy protocol header and none is expected from this source, should be passed on
+        name = 'no-header.unexpected-proxy-protocol.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+          sender = getattr(self, method)
+          (receivedQuery, receivedResponse) = sender(query, response)
+          receivedQuery.id = query.id
+          self.assertEquals(query, receivedQuery)
+          self.assertEquals(response, receivedResponse)
+
+    def testIncomingProxyDest(self):
+        """
+        Unexpected Proxy Protocol: should be dropped
+        """
+        name = 'with-proxy-payload.unexpected-protocol-incoming.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+
+        # Make sure that the proxy payload does NOT turn into a legal qname
+        destAddr = "ff:db8::ffff"
+        destPort = 65535
+        srcAddr = "ff:db8::ffff"
+        srcPort = 65535
+
+        udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+        (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True)
+        self.assertEquals(receivedResponse, None)
+
+        tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+        wire = query.to_wire()
+
+        receivedResponse = None
+        try:
+          conn = self.openTCPConnection(2.0)
+          conn.send(tcpPayload)
+          conn.send(struct.pack("!H", len(wire)))
+          conn.send(wire)
+          receivedResponse = self.recvTCPResponseOverConnection(conn)
+        except socket.timeout:
+          print('timeout')
+        self.assertEquals(receivedResponse, None)