]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add Lua bindings for the incoming network interface
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 11 Apr 2025 09:28:51 +0000 (11:28 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 11 Apr 2025 09:29:48 +0000 (11:29 +0200)
This is useful in Virtual Routing and Forwarding (VRF) environments
where the destination IP address might not be enough to identify the VRF.

pdns/dnsdistdist/dnsdist-lua-bindings-dnsquestion.cc
pdns/dnsdistdist/dnsdist-lua-ffi-interface.h
pdns/dnsdistdist/dnsdist-lua-ffi.cc
pdns/dnsdistdist/docs/reference/dq.rst
pdns/dnsdistdist/test-dnsdist-lua-ffi.cc
regression-tests.dnsdist/test_IncomingInterface.py [new file with mode: 0644]

index c5f204e8534fc8a59bbd00a31a1976bb3e76e57f..d73d3a158407eae35b75b773a48220a1ebc80055 100644 (file)
@@ -161,6 +161,13 @@ void setupLuaBindingsDNSQuestion([[maybe_unused]] LuaContext& luaCtx)
     return dnsQuestion.sni;
   });
 
+  luaCtx.registerFunction<std::string (DNSQuestion::*)() const>("getIncomingInterface", [](const DNSQuestion& dnsQuestion) -> std::string {
+    if (dnsQuestion.ids.cs != nullptr) {
+      return dnsQuestion.ids.cs->interface;
+    }
+    return {};
+  });
+
   luaCtx.registerFunction<std::string (DNSQuestion::*)() const>("getProtocol", [](const DNSQuestion& dnsQuestion) {
     return dnsQuestion.getProtocol().toPrettyString();
   });
@@ -502,6 +509,13 @@ void setupLuaBindingsDNSQuestion([[maybe_unused]] LuaContext& luaCtx)
     return dnsResponse.ids.queryRealTime.udiff();
   });
 
+  luaCtx.registerFunction<std::string (DNSResponse::*)() const>("getIncomingInterface", [](const DNSResponse& dnsResponse) -> std::string {
+    if (dnsResponse.ids.cs != nullptr) {
+      return dnsResponse.ids.cs->interface;
+    }
+    return {};
+  });
+
   luaCtx.registerFunction<void (DNSResponse::*)(std::string)>("sendTrap", []([[maybe_unused]] const DNSResponse& dnsResponse, [[maybe_unused]] boost::optional<std::string> reason) {
 #ifdef HAVE_NET_SNMP
     if (g_snmpAgent != nullptr && dnsdist::configuration::getImmutableConfiguration().d_snmpTrapsEnabled) {
index 5d6bb216c470e8aa4cc59e6767b7e6fab5d29635..00e7a25a95b0f361e30928281e87c01f2340dc12 100644 (file)
@@ -64,6 +64,7 @@ bool dnsdist_ffi_dnsquestion_is_remote_v6(const dnsdist_ffi_dnsquestion_t* dnsQu
 void dnsdist_ffi_dnsquestion_get_remoteaddr(const dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize) __attribute__ ((visibility ("default")));
 void dnsdist_ffi_dnsquestion_get_masked_remoteaddr(dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize, uint8_t bits) __attribute__ ((visibility ("default")));
 uint16_t dnsdist_ffi_dnsquestion_get_remote_port(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+const char* dnsdist_ffi_dnsquestion_get_incoming_interface(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 void dnsdist_ffi_dnsquestion_get_qname_raw(const dnsdist_ffi_dnsquestion_t* dq, const char** qname, size_t* qnameSize) __attribute__ ((visibility ("default")));
 size_t dnsdist_ffi_dnsquestion_get_qname_hash(const dnsdist_ffi_dnsquestion_t* dq, size_t init) __attribute__ ((visibility ("default")));
 uint16_t dnsdist_ffi_dnsquestion_get_qtype(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
index 42269716b26b45be5a4258c069c1968d3348e3ff..6636f2b123b4be40f71efabdaf3b2354ab2338bd 100644 (file)
@@ -124,6 +124,14 @@ uint16_t dnsdist_ffi_dnsquestion_get_remote_port(const dnsdist_ffi_dnsquestion_t
   return dq->dq->ids.origRemote.getPort();
 }
 
+const char* dnsdist_ffi_dnsquestion_get_incoming_interface(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  if (dq == nullptr || dq->dq == nullptr || dq->dq->ids.cs == nullptr) {
+    return nullptr;
+  }
+  return dq->dq->ids.cs->interface.c_str();
+}
+
 void dnsdist_ffi_dnsquestion_get_qname_raw(const dnsdist_ffi_dnsquestion_t* dq, const char** qname, size_t* qnameSize)
 {
   const auto& storage = dq->dq->ids.qname.getStorage();
index 93efa81778be39ac251d9899342cdff797625b54..65337d7012b3670e575ea7bd01db24a10d20f731 100644 (file)
@@ -182,6 +182,18 @@ This state can be modified from the various hooks.
 
     :returns: The scheme of the DoH query, for example ``http`` or ``https``
 
+  .. method:: DNSQuestion:getIncomingInterface() -> string
+
+    .. versionadded:: 2.0.0
+
+    Return the name of the network interface this query was received on, but only if the corresponding frontend
+    has been bound to a specific network interface via the ``interface`` parameter to :func:`addLocal`, :func:`setLocal`,
+    :func:`addTLSLocal`, :func:`addDOHLocal`, :func:`addDOQLocal` or :func:`AddDOH3Local`, or the ``interface`` parameter
+    of a :ref:`frontend <yaml-settings-BindConfiguration>` when the YAML format is used. This is useful in Virtual Routing
+    and Forwarding (VRF) environments where the destination IP address might not be enough to identify the VRF.
+
+    :returns: The name of the network interface this query was received on, or an empty string.
+
   .. method:: DNSQuestion:getProtocol() -> string
 
     .. versionadded:: 1.7.0
index f1145625f6e1cd0a7b3ad6d54748df23e11bbb68..84ba2eb2adc0d88451bd50a712a9d19c3cc74426 100644 (file)
@@ -368,6 +368,30 @@ BOOST_AUTO_TEST_CASE(test_Query)
   BOOST_CHECK_EQUAL(ids.d_protoBufData->d_deviceID, deviceID);
   BOOST_CHECK_EQUAL(ids.d_protoBufData->d_deviceName, deviceName);
   BOOST_CHECK_EQUAL(ids.d_protoBufData->d_requestorID, requestorID);
+
+  /* no frontend yet */
+  BOOST_CHECK(dnsdist_ffi_dnsquestion_get_incoming_interface(nullptr) == nullptr);
+  BOOST_CHECK(dnsdist_ffi_dnsquestion_get_incoming_interface(&lightDQ) == nullptr);
+  {
+    /* frontend without and interface set */
+    const std::string interface{};
+    ClientState frontend(ids.origDest, false, false, 0, interface, {}, false);
+    ids.cs = &frontend;
+    const auto* itfPtr = dnsdist_ffi_dnsquestion_get_incoming_interface(&lightDQ);
+    BOOST_REQUIRE(itfPtr != nullptr);
+    BOOST_CHECK_EQUAL(std::string(itfPtr), interface);
+    ids.cs = nullptr;
+  }
+  {
+    /* frontend with interface set */
+    const std::string interface{"interface-name-0"};
+    ClientState frontend(ids.origDest, false, false, 0, interface, {}, false);
+    ids.cs = &frontend;
+    const auto* itfPtr = dnsdist_ffi_dnsquestion_get_incoming_interface(&lightDQ);
+    BOOST_REQUIRE(itfPtr != nullptr);
+    BOOST_CHECK_EQUAL(std::string(itfPtr), interface);
+    ids.cs = nullptr;
+  }
 }
 
 BOOST_AUTO_TEST_CASE(test_Response)
diff --git a/regression-tests.dnsdist/test_IncomingInterface.py b/regression-tests.dnsdist/test_IncomingInterface.py
new file mode 100644 (file)
index 0000000..e02d3f2
--- /dev/null
@@ -0,0 +1,117 @@
+#!/usr/bin/env python
+import socket
+import unittest
+import dns
+from dnsdisttests import DNSDistTest
+
+def get_loopback_itf():
+    interfaces = socket.if_nameindex()
+    for itf in interfaces:
+        if itf[1] == 'lo':
+            return 'lo'
+    return None
+
+class TestIncomingInterface(DNSDistTest):
+    _lo_itf = get_loopback_itf()
+    _config_template = """
+    local itfName = '%s'
+    addLocal('127.0.0.1:%d', {interface=itfName})
+
+    function checkItf(dq)
+      if dq:getIncomingInterface() ~= itfName then
+        return DNSAction.Spoof, '1.2.3.4'
+      end
+      return DNSAction.None
+    end
+
+    function checkItfResponse(dr)
+      if dr:getIncomingInterface() ~= itfName then
+        return DNSResponseAction.ServFail
+      end
+      return DNSResponseAction.None
+    end
+
+    addAction(AllRule(), LuaAction(checkItf))
+    addResponseAction(AllRule(), LuaResponseAction(checkItfResponse))
+    newServer{address="127.0.0.1:%d"}
+    """
+    _config_params = ['_lo_itf', '_dnsDistPort', '_testServerPort']
+    _skipListeningOnCL = True
+
+    def testItfName(self):
+        """
+        Advanced: Check incoming interface name
+        """
+        if get_loopback_itf() is None:
+            raise unittest.SkipTest('No lo interface')
+
+        name = 'incoming-interface.advanced.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '4.3.2.1')
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEqual(receivedQuery, query)
+            self.assertEqual(receivedResponse, response)
+
+class TestIncomingInterfaceNotSet(DNSDistTest):
+    _lo_itf = get_loopback_itf()
+    _config_template = """
+    local itfName = '%s'
+    addLocal('127.0.0.1:%d')
+
+    function checkItf(dq)
+      if dq:getIncomingInterface() ~= itfName then
+        return DNSAction.Spoof, '1.2.3.4'
+      end
+      return DNSAction.None
+    end
+
+    function checkItfResponse(dr)
+      if dr:getIncomingInterface() ~= itfName then
+        return DNSResponseAction.ServFail
+      end
+      return DNSResponseAction.None
+    end
+
+    addAction(AllRule(), LuaAction(checkItf))
+    addResponseAction(AllRule(), LuaResponseAction(checkItfResponse))
+    newServer{address="127.0.0.1:%d"}
+    """
+    _config_params = ['_lo_itf', '_dnsDistPort', '_testServerPort']
+    _skipListeningOnCL = True
+
+    def testItfName(self):
+        """
+        Advanced: Check incoming interface name (not set)
+        """
+        if get_loopback_itf() is None:
+            raise unittest.SkipTest('No lo interface')
+
+        name = 'incoming-interface-not-set.advanced.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '1.2.3.4')
+        expectedResponse.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEqual(receivedQuery, None)
+            self.assertEqual(receivedResponse, expectedResponse)