]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add Lua bindings for the incoming network interface
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 11 Apr 2025 09:28:51 +0000 (11:28 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 30 Apr 2025 10:57:02 +0000 (12:57 +0200)
This is useful in Virtual Routing and Forwarding (VRF) environments
where the destination IP address might not be enough to identify the VRF.

(cherry picked from commit 72a24734735bf9e0cceaefa54047d015a503e033)

pdns/dnsdist-lua-bindings-dnsquestion.cc
pdns/dnsdistdist/dnsdist-lua-ffi-interface.h
pdns/dnsdistdist/dnsdist-lua-ffi.cc
pdns/dnsdistdist/docs/reference/dq.rst
pdns/dnsdistdist/test-dnsdist-lua-ffi.cc
regression-tests.dnsdist/test_IncomingInterface.py [new file with mode: 0644]

index bd75754b3cb93b3d19a4c0d08a382f166f3da937..549902e4d22f93f0ccb9db10e7d6865d302f5dd1 100644 (file)
@@ -138,6 +138,13 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
       return dq.sni;
     });
 
+  luaCtx.registerFunction<std::string (DNSQuestion::*)() const>("getIncomingInterface", [](const DNSQuestion& dnsQuestion) -> std::string {
+    if (dnsQuestion.ids.cs != nullptr) {
+      return dnsQuestion.ids.cs->interface;
+    }
+    return {};
+  });
+
   luaCtx.registerFunction<std::string (DNSQuestion::*)()const>("getProtocol", [](const DNSQuestion& dq) {
     return dq.getProtocol().toPrettyString();
   });
@@ -458,6 +465,13 @@ private:
     return dnsResponse.ids.queryRealTime.udiff();
   });
 
+  luaCtx.registerFunction<std::string (DNSResponse::*)() const>("getIncomingInterface", [](const DNSResponse& dnsResponse) -> std::string {
+    if (dnsResponse.ids.cs != nullptr) {
+      return dnsResponse.ids.cs->interface;
+    }
+    return {};
+  });
+
   luaCtx.registerFunction<void(DNSResponse::*)(std::string)>("sendTrap", [](const DNSResponse& dr, boost::optional<std::string> reason) {
 #ifdef HAVE_NET_SNMP
       if (g_snmpAgent && g_snmpTrapsEnabled) {
index 7323975bad1b96ce5936437432f13a96941b98fc..5405a97c9d705630aa209c7bbf37f3fc981eaccc 100644 (file)
@@ -64,6 +64,7 @@ bool dnsdist_ffi_dnsquestion_is_remote_v6(const dnsdist_ffi_dnsquestion_t* dnsQu
 void dnsdist_ffi_dnsquestion_get_remoteaddr(const dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize) __attribute__ ((visibility ("default")));
 void dnsdist_ffi_dnsquestion_get_masked_remoteaddr(dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize, uint8_t bits) __attribute__ ((visibility ("default")));
 uint16_t dnsdist_ffi_dnsquestion_get_remote_port(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+const char* dnsdist_ffi_dnsquestion_get_incoming_interface(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 void dnsdist_ffi_dnsquestion_get_qname_raw(const dnsdist_ffi_dnsquestion_t* dq, const char** qname, size_t* qnameSize) __attribute__ ((visibility ("default")));
 size_t dnsdist_ffi_dnsquestion_get_qname_hash(const dnsdist_ffi_dnsquestion_t* dq, size_t init) __attribute__ ((visibility ("default")));
 uint16_t dnsdist_ffi_dnsquestion_get_qtype(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
index f2af6c19f02d1a52771495b7134103895a3f782a..47c8fefa4c574bec3b0dfa2a0a553ed8d86a3b2f 100644 (file)
@@ -121,6 +121,14 @@ uint16_t dnsdist_ffi_dnsquestion_get_remote_port(const dnsdist_ffi_dnsquestion_t
   return dq->dq->ids.origRemote.getPort();
 }
 
+const char* dnsdist_ffi_dnsquestion_get_incoming_interface(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  if (dq == nullptr || dq->dq == nullptr || dq->dq->ids.cs == nullptr) {
+    return nullptr;
+  }
+  return dq->dq->ids.cs->interface.c_str();
+}
+
 void dnsdist_ffi_dnsquestion_get_qname_raw(const dnsdist_ffi_dnsquestion_t* dq, const char** qname, size_t* qnameSize)
 {
   const auto& storage = dq->dq->ids.qname.getStorage();
index c66f1257ad196e445439d74de63a9b2b27c2a02d..d82aa1683f9ab519745d43865a47e160bfee6fba 100644 (file)
@@ -182,6 +182,18 @@ This state can be modified from the various hooks.
 
     :returns: The scheme of the DoH query, for example ``http`` or ``https``
 
+  .. method:: DNSQuestion:getIncomingInterface() -> string
+
+    .. versionadded:: 2.0.0
+
+    Return the name of the network interface this query was received on, but only if the corresponding frontend
+    has been bound to a specific network interface via the ``interface`` parameter to :func:`addLocal`, :func:`setLocal`,
+    :func:`addTLSLocal`, :func:`addDOHLocal`, :func:`addDOQLocal` or :func:`AddDOH3Local`, or the ``interface`` parameter
+    of a :ref:`frontend <yaml-settings-BindConfiguration>` when the YAML format is used. This is useful in Virtual Routing
+    and Forwarding (VRF) environments where the destination IP address might not be enough to identify the VRF.
+
+    :returns: The name of the network interface this query was received on, or an empty string.
+
   .. method:: DNSQuestion:getProtocol() -> string
 
     .. versionadded:: 1.7.0
index ba70003a77ff586a680c6611eff16d3335a8aff3..4d8c31ee01d136c490bf8ce155da0ba671f56b7a 100644 (file)
@@ -373,6 +373,30 @@ BOOST_AUTO_TEST_CASE(test_Query)
   BOOST_CHECK_EQUAL(ids.d_protoBufData->d_deviceID, deviceID);
   BOOST_CHECK_EQUAL(ids.d_protoBufData->d_deviceName, deviceName);
   BOOST_CHECK_EQUAL(ids.d_protoBufData->d_requestorID, requestorID);
+
+  /* no frontend yet */
+  BOOST_CHECK(dnsdist_ffi_dnsquestion_get_incoming_interface(nullptr) == nullptr);
+  BOOST_CHECK(dnsdist_ffi_dnsquestion_get_incoming_interface(&lightDQ) == nullptr);
+  {
+    /* frontend without and interface set */
+    const std::string interface{};
+    ClientState frontend(ids.origDest, false, false, 0, interface, {}, false);
+    ids.cs = &frontend;
+    const auto* itfPtr = dnsdist_ffi_dnsquestion_get_incoming_interface(&lightDQ);
+    BOOST_REQUIRE(itfPtr != nullptr);
+    BOOST_CHECK_EQUAL(std::string(itfPtr), interface);
+    ids.cs = nullptr;
+  }
+  {
+    /* frontend with interface set */
+    const std::string interface{"interface-name-0"};
+    ClientState frontend(ids.origDest, false, false, 0, interface, {}, false);
+    ids.cs = &frontend;
+    const auto* itfPtr = dnsdist_ffi_dnsquestion_get_incoming_interface(&lightDQ);
+    BOOST_REQUIRE(itfPtr != nullptr);
+    BOOST_CHECK_EQUAL(std::string(itfPtr), interface);
+    ids.cs = nullptr;
+  }
 }
 
 BOOST_AUTO_TEST_CASE(test_Response)
diff --git a/regression-tests.dnsdist/test_IncomingInterface.py b/regression-tests.dnsdist/test_IncomingInterface.py
new file mode 100644 (file)
index 0000000..e02d3f2
--- /dev/null
@@ -0,0 +1,117 @@
+#!/usr/bin/env python
+import socket
+import unittest
+import dns
+from dnsdisttests import DNSDistTest
+
+def get_loopback_itf():
+    interfaces = socket.if_nameindex()
+    for itf in interfaces:
+        if itf[1] == 'lo':
+            return 'lo'
+    return None
+
+class TestIncomingInterface(DNSDistTest):
+    _lo_itf = get_loopback_itf()
+    _config_template = """
+    local itfName = '%s'
+    addLocal('127.0.0.1:%d', {interface=itfName})
+
+    function checkItf(dq)
+      if dq:getIncomingInterface() ~= itfName then
+        return DNSAction.Spoof, '1.2.3.4'
+      end
+      return DNSAction.None
+    end
+
+    function checkItfResponse(dr)
+      if dr:getIncomingInterface() ~= itfName then
+        return DNSResponseAction.ServFail
+      end
+      return DNSResponseAction.None
+    end
+
+    addAction(AllRule(), LuaAction(checkItf))
+    addResponseAction(AllRule(), LuaResponseAction(checkItfResponse))
+    newServer{address="127.0.0.1:%d"}
+    """
+    _config_params = ['_lo_itf', '_dnsDistPort', '_testServerPort']
+    _skipListeningOnCL = True
+
+    def testItfName(self):
+        """
+        Advanced: Check incoming interface name
+        """
+        if get_loopback_itf() is None:
+            raise unittest.SkipTest('No lo interface')
+
+        name = 'incoming-interface.advanced.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '4.3.2.1')
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEqual(receivedQuery, query)
+            self.assertEqual(receivedResponse, response)
+
+class TestIncomingInterfaceNotSet(DNSDistTest):
+    _lo_itf = get_loopback_itf()
+    _config_template = """
+    local itfName = '%s'
+    addLocal('127.0.0.1:%d')
+
+    function checkItf(dq)
+      if dq:getIncomingInterface() ~= itfName then
+        return DNSAction.Spoof, '1.2.3.4'
+      end
+      return DNSAction.None
+    end
+
+    function checkItfResponse(dr)
+      if dr:getIncomingInterface() ~= itfName then
+        return DNSResponseAction.ServFail
+      end
+      return DNSResponseAction.None
+    end
+
+    addAction(AllRule(), LuaAction(checkItf))
+    addResponseAction(AllRule(), LuaResponseAction(checkItfResponse))
+    newServer{address="127.0.0.1:%d"}
+    """
+    _config_params = ['_lo_itf', '_dnsDistPort', '_testServerPort']
+    _skipListeningOnCL = True
+
+    def testItfName(self):
+        """
+        Advanced: Check incoming interface name (not set)
+        """
+        if get_loopback_itf() is None:
+            raise unittest.SkipTest('No lo interface')
+
+        name = 'incoming-interface-not-set.advanced.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '1.2.3.4')
+        expectedResponse.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEqual(receivedQuery, None)
+            self.assertEqual(receivedResponse, expectedResponse)