]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add regression tests for the proxy protocol TLV FFI accessor
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 17 Sep 2024 08:52:43 +0000 (10:52 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 17 Sep 2024 08:52:43 +0000 (10:52 +0200)
regression-tests.dnsdist/test_ProxyProtocol.py

index 78677b3a7149f7e9a4775e6264b78f36947717ba..b4d2c09f60e5fa343c112d8395f0d2ece35b3a03 100644 (file)
@@ -444,7 +444,7 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
 
     def testIncomingProxyDest(self):
         """
-        Incoming Proxy Protocol: values from Lua
+        Incoming Proxy Protocol: get forwarded destination
         """
         name = 'get-forwarded-dest.proxy-protocol-incoming.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
@@ -862,6 +862,136 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
             backgroundThreads[backgroundThread] = False
         cls.killProcess(cls._dnsdist)
 
+class TestProxyProtocolIncomingValuesViaLua(DNSDistTest):
+    """
+    Check that dnsdist can retrieve incoming Proxy Protocol TLV values via Lua
+    """
+
+    _config_template = """
+    setProxyProtocolACL( { "127.0.0.1/32" } )
+
+    function checkValues(dq)
+      if dq.localaddr:toStringWithPort() ~= '[2001:db8::9]:9999' then
+        return DNSAction.Spoof, "invalid.local.addr."
+      end
+      if dq.remoteaddr:toStringWithPort() ~= '[2001:db8::8]:8888' then
+        return DNSAction.Spoof, "invalid.remote.addr."
+      end
+      local values = dq:getProxyProtocolValues()
+      if #values ~= 3 then
+        return DNSAction.Spoof, #values .. ".invalid.values.count."
+      end
+      if values[2] ~= 'foo' then
+        return DNSAction.Spoof, "2.foo.value.missing."
+      end
+      if values[3] ~= 'proxy' then
+        return DNSAction.Spoof, "3.proxy.value.missing."
+      end
+      return DNSAction.Spoof, "ok."
+    end
+
+    local ffi = require("ffi")
+    local C = ffi.C
+    ffi.cdef[[
+      typedef unsigned int socklen_t;
+      const char *inet_ntop(int af, const void *restrict src,
+                      char *restrict dst, socklen_t size);
+    ]]
+    local ret_ptr = ffi.new("const char *[1]")
+    local ret_ptr_param = ffi.cast("const void **", ret_ptr)
+    local ret_size = ffi.new("size_t[1]")
+    local ret_size_param = ffi.cast("size_t*", ret_size)
+    local ret_pp_ptr = ffi.new("const dnsdist_ffi_proxy_protocol_value_t*[1]")
+    local ret_pp_ptr_param = ffi.cast("const dnsdist_ffi_proxy_protocol_value_t**", ret_pp_ptr)
+    local inet_buffer = ffi.new("char[?]", 256)
+
+    function sendResult(dqffi, str)
+      C.dnsdist_ffi_dnsquestion_set_result(dqffi, str, #str)
+      return DNSAction.Spoof
+    end
+
+    function checkValuesFFI(dqffi)
+      C.dnsdist_ffi_dnsquestion_get_localaddr(dqffi, ret_ptr_param, ret_size_param)
+      local addr = C.inet_ntop(10, ret_ptr[0], inet_buffer, 256)
+      if addr == nil or ffi.string(addr) ~= '2001:db8::9' then
+        return sendResult(dqffi, "invalid.local.addr.")
+      end
+      C.dnsdist_ffi_dnsquestion_get_remoteaddr(dqffi, ret_ptr_param, ret_size_param)
+      local addr = C.inet_ntop(10, ret_ptr[0], inet_buffer, 256)
+      if addr == nil or ffi.string(addr) ~= '2001:db8::8' then
+        return sendResult(dqffi, "invalid.remote.addr.")
+      end
+
+      local count = tonumber(C.dnsdist_ffi_dnsquestion_get_proxy_protocol_values(dqffi, ret_pp_ptr_param))
+      if count ~= 2 then
+        return sendResult(dqffi, count .. ".invalid.values.count.")
+      end
+
+      local foo_seen = false
+      local proxy_seen = false
+      for counter = 0, count - 1 do
+        local entry = ret_pp_ptr[0][counter]
+        if entry.type == 2 and ffi.string(entry.value, entry.size) == 'foo' then
+          foo_seen = true
+        elseif entry.type == 3 and ffi.string(entry.value, entry.size) == 'proxy' then
+          proxy_seen = true
+        end
+      end
+      if not foo_seen then
+        return sendResult(dqffi, "2.foo.value.missing.")
+      end
+      if not proxy_seen then
+        return sendResult(dqffi, "3.proxy.value.missing.")
+      end
+
+      return sendResult(dqffi, "ok.")
+    end
+
+    addAction("proxy-protocol-incoming-values-via-lua.tests.powerdns.com.", LuaAction(checkValues))
+    addAction("proxy-protocol-incoming-values-via-lua-ffi.tests.powerdns.com.", LuaFFIAction(checkValuesFFI))
+
+    newServer{address="127.0.0.1:%d"}
+    """
+
+    def testProxyUDPWithValuesFromLua(self):
+        """
+        Incoming Proxy Protocol: values from Lua
+        """
+        destAddr = "2001:db8::9"
+        destPort = 9999
+        srcAddr = "2001:db8::8"
+        srcPort = 8888
+        names = ['proxy-protocol-incoming-values-via-lua.tests.powerdns.com.',
+                 'proxy-protocol-incoming-values-via-lua-ffi.tests.powerdns.com.'
+                 ]
+        for name in names:
+            query = dns.message.make_query(name, 'A', 'IN')
+            # dnsdist set RA = RD for spoofed responses
+            query.flags &= ~dns.flags.RD
+            response = dns.message.make_response(query)
+
+            expectedResponse = dns.message.make_response(query)
+            rrset = dns.rrset.from_text(name,
+                                        60,
+                                        dns.rdataclass.IN,
+                                        dns.rdatatype.CNAME,
+                                        'ok.')
+            expectedResponse.answer.append(rrset)
+
+            udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+            (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True)
+            self.assertEqual(expectedResponse, receivedResponse)
+
+            conn = self.openTCPConnection(2.0)
+            try:
+                conn.send(udpPayload)
+                conn.send(struct.pack("!H", len(query.to_wire())))
+                conn.send(query.to_wire())
+                receivedResponse = self.recvTCPResponseOverConnection(conn)
+            except socket.timeout:
+                print('timeout')
+            self.assertEqual(expectedResponse, receivedResponse)
+
 class TestProxyProtocolNotExpected(DNSDistTest):
     """
     dnsdist is configured to expect a Proxy Protocol header on incoming queries but not from 127.0.0.1