]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add Proxy Protocol v2 support to `TeeAction`
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 17 Nov 2023 15:06:54 +0000 (16:06 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 17 Nov 2023 15:08:15 +0000 (16:08 +0100)
pdns/dnsdist-lua-actions.cc
pdns/dnsdistdist/docs/rules-actions.rst
regression-tests.dnsdist/proxyprotocolutils.py [new file with mode: 0644]
regression-tests.dnsdist/test_ProxyProtocol.py
regression-tests.dnsdist/test_TeeAction.py

index 575c8f3465b98cd933aa1247addef6650ac592ea..4463ffd5f80b885518eb18125ba5ba028e661f90 100644 (file)
@@ -30,6 +30,7 @@
 #include "dnsdist-lua-ffi.hh"
 #include "dnsdist-mac-address.hh"
 #include "dnsdist-protobuf.hh"
+#include "dnsdist-proxy-protocol.hh"
 #include "dnsdist-kvs.hh"
 #include "dnsdist-svc.hh"
 
@@ -130,18 +131,18 @@ class TeeAction : public DNSAction
 {
 public:
   // this action does not stop the processing
-  TeeAction(const ComboAddress& rca, const boost::optional<ComboAddress>& lca, bool addECS=false);
+  TeeAction(const ComboAddress& rca, const boost::optional<ComboAddress>& lca, bool addECS = false, bool addProxyProtocol = false);
   ~TeeAction() override;
   DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override;
   std::string toString() const override;
   std::map<std::string, double> getStats() const override;
 
 private:
-  ComboAddress d_remote;
-  std::thread d_worker;
   void worker();
 
-  int d_fd{-1};
+  ComboAddress d_remote;
+  std::thread d_worker;
+  Socket d_socket;
   mutable std::atomic<unsigned long> d_senderrors{0};
   unsigned long d_recverrors{0};
   mutable std::atomic<unsigned long> d_queries{0};
@@ -156,32 +157,26 @@ private:
   stat_t d_otherrcode{0};
   std::atomic<bool> d_pleaseQuit{false};
   bool d_addECS{false};
+  bool d_addProxyProtocol{false};
 };
 
-TeeAction::TeeAction(const ComboAddress& rca, const boost::optional<ComboAddress>& lca, bool addECS)
-  : d_remote(rca), d_addECS(addECS)
+TeeAction::TeeAction(const ComboAddress& rca, const boost::optional<ComboAddress>& lca, bool addECS, bool addProxyProtocol)
+  : d_remote(rca), d_socket(d_remote.sin4.sin_family, SOCK_DGRAM, 0), d_addECS(addECS), d_addProxyProtocol(addProxyProtocol)
 {
-  d_fd=SSocket(d_remote.sin4.sin_family, SOCK_DGRAM, 0);
-  try {
-    if (lca) {
-      SBind(d_fd, *lca);
-    }
-    SConnect(d_fd, d_remote);
-    setNonBlocking(d_fd);
-    d_worker=std::thread([this](){worker();});
-  }
-  catch (...) {
-    if (d_fd != -1) {
-      close(d_fd);
-    }
-    throw;
+  if (lca) {
+    d_socket.bind(*lca, false);
   }
+  d_socket.connect(d_remote);
+  d_socket.setNonBlocking();
+  d_worker = std::thread([this]() {
+    worker();
+  });
 }
 
 TeeAction::~TeeAction()
 {
-  d_pleaseQuit=true;
-  close(d_fd);
+  d_pleaseQuit = true;
+  close(d_socket.releaseHandle());
   d_worker.join();
 }
 
@@ -189,28 +184,38 @@ DNSAction::Action TeeAction::operator()(DNSQuestion* dq, std::string* ruleresult
 {
   if (dq->overTCP()) {
     d_tcpdrops++;
+    return DNSAction::Action::None;
   }
-  else {
-    ssize_t res;
-    d_queries++;
 
-    if(d_addECS) {
-      PacketBuffer query(dq->getData());
-      bool ednsAdded = false;
-      bool ecsAdded = false;
+  d_queries++;
 
-      std::string newECSOption;
-      generateECSOption(dq->ecs ? dq->ecs->getNetwork() : dq->ids.origRemote, newECSOption, dq->ecs ? dq->ecs->getBits() : dq->ecsPrefixLength);
+  PacketBuffer query;
+  if (d_addECS) {
+    query = dq->getData();
+    bool ednsAdded = false;
+    bool ecsAdded = false;
 
-      if (!handleEDNSClientSubnet(query, dq->getMaximumSize(), dq->ids.qname.wirelength(), ednsAdded, ecsAdded, dq->ecsOverride, newECSOption)) {
-        return DNSAction::Action::None;
-      }
+    std::string newECSOption;
+    generateECSOption(dq->ecs ? dq->ecs->getNetwork() : dq->ids.origRemote, newECSOption, dq->ecs ? dq->ecs->getBits() : dq->ecsPrefixLength);
 
-      res = send(d_fd, query.data(), query.size(), 0);
+    if (!handleEDNSClientSubnet(query, dq->getMaximumSize(), dq->ids.qname.wirelength(), ednsAdded, ecsAdded, dq->ecsOverride, newECSOption)) {
+      return DNSAction::Action::None;
     }
-    else {
-      res = send(d_fd, dq->getData().data(), dq->getData().size(), 0);
+  }
+
+  if (d_addProxyProtocol) {
+    auto proxyPayload = getProxyProtocolPayload(*dq);
+    if (query.empty()) {
+      query = dq->getData();
+    }
+    if (!addProxyProtocol(query, proxyPayload)) {
+      return DNSAction::Action::None;
     }
+  }
+
+  {
+    const PacketBuffer& payload = query.empty() ? dq->getData() : query;
+    auto res = send(d_socket.getHandle(), payload.data(), payload.size(), 0);
 
     if (res <= 0) {
       d_senderrors++;
@@ -222,7 +227,7 @@ DNSAction::Action TeeAction::operator()(DNSQuestion* dq, std::string* ruleresult
 
 std::string TeeAction::toString() const
 {
-  return "tee to "+d_remote.toStringWithPort();
+  return "tee to " + d_remote.toStringWithPort();
 }
 
 std::map<std::string,double> TeeAction::getStats() const
@@ -247,7 +252,7 @@ void TeeAction::worker()
   ssize_t res = 0;
   const dnsheader_aligned dh(packet.data());
   for (;;) {
-    res = waitForData(d_fd, 0, 250000);
+    res = waitForData(d_socket.getHandle(), 0, 250000);
     if (d_pleaseQuit) {
       break;
     }
@@ -259,7 +264,7 @@ void TeeAction::worker()
     if (res == 0) {
       continue;
     }
-    res = recv(d_fd, packet.data(), packet.size(), 0);
+    res = recv(d_socket.getHandle(), packet.data(), packet.size(), 0);
     if (static_cast<size_t>(res) <= sizeof(struct dnsheader)) {
       d_recverrors++;
     }
@@ -2739,13 +2744,13 @@ void setupLuaActions(LuaContext& luaCtx)
     });
 #endif /* DISABLE_PROTOBUF */
 
-  luaCtx.writeFunction("TeeAction", [](const std::string& remote, boost::optional<bool> addECS, boost::optional<std::string> local) {
+  luaCtx.writeFunction("TeeAction", [](const std::string& remote, boost::optional<bool> addECS, boost::optional<std::string> local, boost::optional<bool> addProxyProtocol) {
       boost::optional<ComboAddress> localAddr{boost::none};
       if (local) {
         localAddr = ComboAddress(*local, 0);
       }
 
-      return std::shared_ptr<DNSAction>(new TeeAction(ComboAddress(remote, 53), localAddr, addECS ? *addECS : false));
+      return std::shared_ptr<DNSAction>(new TeeAction(ComboAddress(remote, 53), localAddr, addECS ? *addECS : false, addProxyProtocol ? *addProxyProtocol : false));
     });
 
   luaCtx.writeFunction("SetECSPrefixLengthAction", [](uint16_t v4PrefixLength, uint16_t v6PrefixLength) {
index 090b5a7701782e51e75a237ffb53db15de602a5a..8ba254b36933e66ff0f7cca856bf56e314ea2535 100644 (file)
@@ -1850,18 +1850,24 @@ The following actions exist.
   Before 1.7.0 this action was performed even when the query had been received over TCP, which required the use of :func:`TCPRule` to
   prevent the TC bit from being set over TCP transports.
 
-.. function:: TeeAction(remote[, addECS[, local]])
+.. function:: TeeAction(remote[, addECS[, local [, addProxyProtocol]]])
 
   .. versionchanged:: 1.8.0
     Added the optional parameter ``local``.
 
+  .. versionchanged:: 1.9.0
+    Added the optional parameter ``addProxyProtocol``.
+
   Send copy of query to ``remote``, keep stats on responses.
   If ``addECS`` is set to true, EDNS Client Subnet information will be added to the query.
+  If ``addProxyProtocol`` is set to true, a Proxy Protocol v2 payload will be prepended in front of the query.
   If ``local`` has provided a value like "192.0.2.53", :program:`dnsdist` will try binding that address as local address when sending the queries.
   Subsequent rules are processed after this action.
 
   :param string remote: An IP:PORT combination to send the copied queries to
-  :param bool addECS: Whether or not to add ECS information. Default false
+  :param bool addECS: Whether to add ECS information. Default false.
+  :param str local: The local address to use to send queries. The default is to let the kernel pick one.
+  :param bool addProxyProtocol: Whether to prepend a proxy protocol v2 payload in front of the query. Default to false.
 
 .. function:: TempFailureCacheTTLAction(ttl)
 
diff --git a/regression-tests.dnsdist/proxyprotocolutils.py b/regression-tests.dnsdist/proxyprotocolutils.py
new file mode 100644 (file)
index 0000000..9d00219
--- /dev/null
@@ -0,0 +1,120 @@
+#!/usr/bin/env python
+import copy
+import dns
+import socket
+import struct
+import sys
+
+from proxyprotocol import ProxyProtocol
+
+def ProxyProtocolUDPResponder(port, fromQueue, toQueue):
+    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+    try:
+        sock.bind(("127.0.0.1", port))
+    except socket.error as e:
+        print("Error binding in the Proxy Protocol UDP responder: %s" % str(e))
+        sys.exit(1)
+
+    while True:
+        data, addr = sock.recvfrom(4096)
+
+        proxy = ProxyProtocol()
+        if len(data) < proxy.HEADER_SIZE:
+            continue
+
+        if not proxy.parseHeader(data):
+            continue
+
+        if proxy.local:
+            # likely a healthcheck
+            data = data[proxy.HEADER_SIZE:]
+            request = dns.message.from_wire(data)
+            response = dns.message.make_response(request)
+            wire = response.to_wire()
+            sock.settimeout(2.0)
+            sock.sendto(wire, addr)
+            sock.settimeout(None)
+
+            continue
+
+        payload = data[:(proxy.HEADER_SIZE + proxy.contentLen)]
+        dnsData = data[(proxy.HEADER_SIZE + proxy.contentLen):]
+        toQueue.put([payload, dnsData], True, 2.0)
+        # computing the correct ID for the response
+        request = dns.message.from_wire(dnsData)
+        response = fromQueue.get(True, 2.0)
+        response.id = request.id
+
+        sock.settimeout(2.0)
+        sock.sendto(response.to_wire(), addr)
+        sock.settimeout(None)
+
+    sock.close()
+
+def ProxyProtocolTCPResponder(port, fromQueue, toQueue):
+    # be aware that this responder will not accept a new connection
+    # until the last one has been closed. This is done on purpose to
+    # to check for connection reuse, making sure that a lot of connections
+    # are not opened in parallel.
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+    try:
+        sock.bind(("127.0.0.1", port))
+    except socket.error as e:
+        print("Error binding in the TCP responder: %s" % str(e))
+        sys.exit(1)
+
+    sock.listen(100)
+    while True:
+        (conn, _) = sock.accept()
+        conn.settimeout(5.0)
+        # try to read the entire Proxy Protocol header
+        proxy = ProxyProtocol()
+        header = conn.recv(proxy.HEADER_SIZE)
+        if not header:
+            conn.close()
+            continue
+
+        if not proxy.parseHeader(header):
+            conn.close()
+            continue
+
+        proxyContent = conn.recv(proxy.contentLen)
+        if not proxyContent:
+            conn.close()
+            continue
+
+        payload = header + proxyContent
+        while True:
+          try:
+            data = conn.recv(2)
+          except socket.timeout:
+            data = None
+
+          if not data:
+            conn.close()
+            break
+
+          (datalen,) = struct.unpack("!H", data)
+          data = conn.recv(datalen)
+
+          toQueue.put([payload, data], True, 2.0)
+
+          response = copy.deepcopy(fromQueue.get(True, 2.0))
+          if not response:
+            conn.close()
+            break
+
+          # computing the correct ID for the response
+          request = dns.message.from_wire(data)
+          response.id = request.id
+
+          wire = response.to_wire()
+          conn.send(struct.pack("!H", len(wire)))
+          conn.send(wire)
+
+        conn.close()
+
+    sock.close()
index dd4ca4fbefb9dc010d537ae3e1f6a53c33a3e889..a29852c8b6c704780ecf27373b34487ddb529c40 100644 (file)
@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 
-import copy
 import dns
 import selectors
 import socket
@@ -12,6 +11,7 @@ import time
 
 from dnsdisttests import DNSDistTest, pickAvailablePort
 from proxyprotocol import ProxyProtocol
+from proxyprotocolutils import ProxyProtocolUDPResponder, ProxyProtocolTCPResponder
 from dnsdistdohtests import DNSDistDOHTest
 
 # Python2/3 compatibility hacks
@@ -20,118 +20,6 @@ try:
 except ImportError:
   from Queue import Queue
 
-def ProxyProtocolUDPResponder(port, fromQueue, toQueue):
-    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
-    try:
-        sock.bind(("127.0.0.1", port))
-    except socket.error as e:
-        print("Error binding in the Proxy Protocol UDP responder: %s" % str(e))
-        sys.exit(1)
-
-    while True:
-        data, addr = sock.recvfrom(4096)
-
-        proxy = ProxyProtocol()
-        if len(data) < proxy.HEADER_SIZE:
-            continue
-
-        if not proxy.parseHeader(data):
-            continue
-
-        if proxy.local:
-            # likely a healthcheck
-            data = data[proxy.HEADER_SIZE:]
-            request = dns.message.from_wire(data)
-            response = dns.message.make_response(request)
-            wire = response.to_wire()
-            sock.settimeout(2.0)
-            sock.sendto(wire, addr)
-            sock.settimeout(None)
-
-            continue
-
-        payload = data[:(proxy.HEADER_SIZE + proxy.contentLen)]
-        dnsData = data[(proxy.HEADER_SIZE + proxy.contentLen):]
-        toQueue.put([payload, dnsData], True, 2.0)
-        # computing the correct ID for the response
-        request = dns.message.from_wire(dnsData)
-        response = fromQueue.get(True, 2.0)
-        response.id = request.id
-
-        sock.settimeout(2.0)
-        sock.sendto(response.to_wire(), addr)
-        sock.settimeout(None)
-
-    sock.close()
-
-def ProxyProtocolTCPResponder(port, fromQueue, toQueue):
-    # be aware that this responder will not accept a new connection
-    # until the last one has been closed. This is done on purpose to
-    # to check for connection reuse, making sure that a lot of connections
-    # are not opened in parallel.
-    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
-    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
-    try:
-        sock.bind(("127.0.0.1", port))
-    except socket.error as e:
-        print("Error binding in the TCP responder: %s" % str(e))
-        sys.exit(1)
-
-    sock.listen(100)
-    while True:
-        (conn, _) = sock.accept()
-        conn.settimeout(5.0)
-        # try to read the entire Proxy Protocol header
-        proxy = ProxyProtocol()
-        header = conn.recv(proxy.HEADER_SIZE)
-        if not header:
-            conn.close()
-            continue
-
-        if not proxy.parseHeader(header):
-            conn.close()
-            continue
-
-        proxyContent = conn.recv(proxy.contentLen)
-        if not proxyContent:
-            conn.close()
-            continue
-
-        payload = header + proxyContent
-        while True:
-          try:
-            data = conn.recv(2)
-          except socket.timeout:
-            data = None
-
-          if not data:
-            conn.close()
-            break
-
-          (datalen,) = struct.unpack("!H", data)
-          data = conn.recv(datalen)
-
-          toQueue.put([payload, data], True, 2.0)
-
-          response = copy.deepcopy(fromQueue.get(True, 2.0))
-          if not response:
-            conn.close()
-            break
-
-          # computing the correct ID for the response
-          request = dns.message.from_wire(data)
-          response.id = request.id
-
-          wire = response.to_wire()
-          conn.send(struct.pack("!H", len(wire)))
-          conn.send(wire)
-
-        conn.close()
-
-    sock.close()
-
 toProxyQueue = Queue()
 fromProxyQueue = Queue()
 proxyResponderPort = pickAvailablePort()
index 373d90f6454327ad9678e254250b833565e0e9f1..0516fbc505c7d98a31b78bad59c5c9a14ad8f37f 100644 (file)
@@ -4,22 +4,27 @@ import threading
 import clientsubnetoption
 import dns
 from dnsdisttests import DNSDistTest, Queue, pickAvailablePort
+from proxyprotocolutils import ProxyProtocolUDPResponder, ProxyProtocolTCPResponder
 
 class TestTeeAction(DNSDistTest):
 
     _consoleKey = DNSDistTest.generateConsoleKey()
     _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
     _teeServerPort = pickAvailablePort()
+    _teeProxyServerPort = pickAvailablePort()
     _toTeeQueue = Queue()
     _fromTeeQueue = Queue()
+    _toTeeProxyQueue = Queue()
+    _fromTeeProxyQueue = Queue()
     _config_template = """
     setKey("%s")
     controlSocket("127.0.0.1:%s")
     newServer{address="127.0.0.1:%d"}
     addAction(QTypeRule(DNSQType.A), TeeAction("127.0.0.1:%d", true))
     addAction(QTypeRule(DNSQType.AAAA), TeeAction("127.0.0.1:%d", false))
+    addAction(QTypeRule(DNSQType.ANY), TeeAction("127.0.0.1:%d", false, '127.0.0.1', true))
     """
-    _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_teeServerPort', '_teeServerPort']
+    _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_teeServerPort', '_teeServerPort', '_teeProxyServerPort']
     @classmethod
     def startResponders(cls):
         print("Launching responders..")
@@ -36,6 +41,10 @@ class TestTeeAction(DNSDistTest):
         cls._TeeResponder.daemon = True
         cls._TeeResponder.start()
 
+        cls._TeeProxyResponder = threading.Thread(name='Proxy Protocol Tee Responder', target=ProxyProtocolUDPResponder, args=[cls._teeProxyServerPort, cls._toTeeProxyQueue, cls._fromTeeProxyQueue])
+        cls._TeeProxyResponder.daemon = True
+        cls._TeeProxyResponder.start()
+
     def testTeeWithECS(self):
         """
         TeeAction: ECS
@@ -130,4 +139,50 @@ responses\t%d
 send-errors\t0
 servfails\t0
 tcp-drops\t0
+""" % (numberOfQueries, numberOfQueries, numberOfQueries))
+
+    def testTeeWithProxy(self):
+        """
+        TeeAction: Proxy
+        """
+        name = 'proxy.tee.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'ANY', 'IN')
+        response = dns.message.make_response(query)
+
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+
+        numberOfQueries = 10
+        for _ in range(numberOfQueries):
+            # push the response to the Tee Proxy server
+            self._toTeeProxyQueue.put(response, True, 2.0)
+
+            (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
+
+            # retrieve the query from the Tee Proxy server
+            [payload, teedQuery] = self._fromTeeProxyQueue.get(True, 2.0)
+            self.checkMessageNoEDNS(query, dns.message.from_wire(teedQuery))
+            self.checkMessageProxyProtocol(payload, '127.0.0.1', '127.0.0.1', False)
+
+        # check the TeeAction stats
+        stats = self.sendConsoleCommand("getAction(0):printStats()")
+        self.assertEqual(stats, """noerrors\t%d
+nxdomains\t0
+other-rcode\t0
+queries\t%d
+recv-errors\t0
+refuseds\t0
+responses\t%d
+send-errors\t0
+servfails\t0
+tcp-drops\t0
 """ % (numberOfQueries, numberOfQueries, numberOfQueries))