]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add TCP management options from rfc7766 section 10 4611/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 21 Oct 2016 09:22:23 +0000 (11:22 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 14 Dec 2016 14:39:56 +0000 (15:39 +0100)
pdns/README-dnsdist.md
pdns/dnsdist-console.cc
pdns/dnsdist-lua.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist.hh
pdns/misc.cc
pdns/misc.hh
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_AXFR.py
regression-tests.dnsdist/test_TCPLimits.py [new file with mode: 0644]

index 263914d37d17ba2c199935836fb8294567220c36..f90459ea36870fd5c903368360c16afe9288a6f4 100644 (file)
@@ -1488,6 +1488,9 @@ instantiate a server with additional parameters
         * member `add(DNSName)`: add this DNSName to the node
  * Tuning related:
     * `setMaxTCPClientThreads(n)`: set the maximum of TCP client threads, handling TCP connections
+    * `setMaxTCPConnectionDuration(n)`: set the maximum duration of an incoming TCP connection, in seconds. 0 (the default) means unlimited
+    * `setMaxTCPConnectionsPerClient(n)`: set the maximum number of TCP connections per client. 0 (the default) means unlimited
+    * `setMaxTCPQueriesPerConnection(n)`: set the maximum number of queries in an incoming TCP connection. 0 (the default) means unlimited
     * `setMaxTCPQueuedConnections(n)`: set the maximum number of TCP connections queued (waiting to be picked up by a client thread), defaults to 1000. 0 means unlimited
     * `setMaxUDPOutstanding(n)`: set the maximum number of outstanding UDP queries to a given backend server. This can only be set at configuration time and defaults to 10240
     * `setCacheCleaningDelay(n)`: set the interval in seconds between two runs of the cache cleaning algorithm, removing expired entries
index 45052f08db4997a28e8a44a9006a86186a6473b1..6a1c7fc39c1a9c599ef0723060b902102b70cc10 100644 (file)
@@ -332,6 +332,9 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "setKey", true, "key", "set access key to that key" },
   { "setLocal", true, "netmask, [true], [false], [TCP Fast Open queue size]", "reset list of addresses we listen on to this address. Second optional parameter sets TCP or not. Third optional parameter sets SO_REUSEPORT when available. Last parameter sets the TCP Fast Open queue size, enabling TCP Fast Open when available and the value is larger than 0." },
   { "setMaxTCPClientThreads", true, "n", "set the maximum of TCP client threads, handling TCP connections" },
+  { "setMaxTCPConnectionDuration", true, "n", "set the maximum duration of an incoming TCP connection, in seconds. 0 means unlimited" },
+  { "setMaxTCPConnectionsPerClient", true, "n", "set the maximum number of TCP connections per client. 0 means unlimited" },
+  { "setMaxTCPQueriesPerConnection", true, "n", "set the maximum number of queries in an incoming TCP connection. 0 means unlimited" },
   { "setMaxTCPQueuedConnections", true, "n", "set the maximum number of TCP connections queued (waiting to be picked up by a client thread)" },
   { "setMaxUDPOutstanding", true, "n", "set the maximum number of outstanding UDP queries to a given backend server. This can only be set at configuration time and defaults to 10240" },
   { "setQueryCount", true, "bool", "set whether queries should be counted" },
index 3337bf1d0f6e0d3d117f494c6e56bb9421faa746..a6386a606deff42abbcfc76e6e85a9bf906ce399 100644 (file)
@@ -1520,6 +1520,30 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
       }
     });
 
+  g_lua.writeFunction("setMaxTCPQueriesPerConnection", [](size_t max) {
+      if (!g_configurationDone) {
+        g_maxTCPQueriesPerConn = max;
+      } else {
+        g_outputBuffer="The maximum number of queries per TCP connection cannot be altered at runtime!\n";
+      }
+    });
+
+  g_lua.writeFunction("setMaxTCPConnectionsPerClient", [](size_t max) {
+      if (!g_configurationDone) {
+        g_maxTCPConnectionsPerClient = max;
+      } else {
+        g_outputBuffer="The maximum number of TCP connection per client cannot be altered at runtime!\n";
+      }
+    });
+
+  g_lua.writeFunction("setMaxTCPConnectionDuration", [](size_t max) {
+      if (!g_configurationDone) {
+        g_maxTCPConnectionDuration = max;
+      } else {
+        g_outputBuffer="The maximum duration of a TCP connection cannot be altered at runtime!\n";
+      }
+    });
+
   g_lua.writeFunction("showTCPStats", [] {
       setLuaNoSideEffect();
       boost::format fmt("%-10d %-10d %-10d %-10d\n");
index 69ea938785d278b986b2039543ec34a4cf8f8d5a..a60fb7552ca78916349af65c469bc4b60cc75246 100644 (file)
@@ -75,8 +75,25 @@ struct ConnectionInfo
 };
 
 uint64_t g_maxTCPQueuedConnections{1000};
+size_t g_maxTCPQueriesPerConn{0};
+size_t g_maxTCPConnectionDuration{0};
+size_t g_maxTCPConnectionsPerClient{0};
+static std::mutex tcpClientsCountMutex;
+static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> tcpClientsCount;
+
 void* tcpClientThread(int pipefd);
 
+static void decrementTCPClientCount(const ComboAddress& client)
+{
+  if (g_maxTCPConnectionsPerClient) {
+    std::lock_guard<std::mutex> lock(tcpClientsCountMutex);
+    tcpClientsCount[client]--;
+    if (tcpClientsCount[client] == 0) {
+      tcpClientsCount.erase(client);
+    }
+  }
+}
+
 void TCPClientCollection::addTCPClientThread()
 {
   vinfolog("Adding TCP Client thread");
@@ -170,6 +187,18 @@ static bool sendResponseToClient(int fd, const char* response, uint16_t response
   return true;
 }
 
+static bool maxConnectionDurationReached(unsigned int maxConnectionDuration, time_t start, unsigned int& remainingTime)
+{
+  if (maxConnectionDuration) {
+    time_t elapsed = time(NULL) - start;
+    if (elapsed >= maxConnectionDuration) {
+      return true;
+    }
+    remainingTime = maxConnectionDuration - elapsed;
+  }
+  return false;
+}
+
 std::shared_ptr<TCPClientCollection> g_tcpclientthreads;
 
 void* tcpClientThread(int pipefd)
@@ -220,6 +249,9 @@ void* tcpClientThread(int pipefd)
     memset(&dest, 0, sizeof(dest));
     dest.sin4.sin_family = ci.remote.sin4.sin_family;
     socklen_t len = dest.getSocklen();
+    size_t queriesCount = 0;
+    time_t connectionStartTime = time(NULL);
+
     if (!setNonBlocking(ci.fd))
       goto drop;
 
@@ -229,6 +261,7 @@ void* tcpClientThread(int pipefd)
 
     try {
       for(;;) {
+        unsigned int remainingTime = 0;
         ds = nullptr;
         outstanding = false;
 
@@ -238,6 +271,18 @@ void* tcpClientThread(int pipefd)
         ci.cs->queries++;
         g_stats.queries++;
 
+        queriesCount++;
+
+        if (g_maxTCPQueriesPerConn && queriesCount > g_maxTCPQueriesPerConn) {
+          vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", ci.remote.toStringWithPort(), queriesCount, g_maxTCPQueriesPerConn);
+          break;
+        }
+
+        if (maxConnectionDurationReached(g_maxTCPConnectionDuration, connectionStartTime, remainingTime)) {
+          vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", ci.remote.toStringWithPort());
+          break;
+        }
+
         if (qlen < sizeof(dnsheader)) {
           g_stats.nonCompliantQueries++;
           break;
@@ -251,7 +296,7 @@ void* tcpClientThread(int pipefd)
         size_t querySize = qlen <= 4096 ? qlen + 512 : qlen;
         char queryBuffer[querySize];
         const char* query = queryBuffer;
-        readn2WithTimeout(ci.fd, queryBuffer, qlen, g_tcpRecvTimeout);
+        readn2WithTimeout(ci.fd, queryBuffer, qlen, g_tcpRecvTimeout, remainingTime);
 
 #ifdef HAVE_DNSCRYPT
         std::shared_ptr<DnsCryptQuery> dnsCryptQuery = 0;
@@ -545,18 +590,18 @@ void* tcpClientThread(int pipefd)
       outstanding = false;
       --ds->outstanding;
     }
+    decrementTCPClientCount(ci.remote);
   }
   return 0;
 }
 
-
 /* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and 
    they will hand off to worker threads & spawn more of them if required
 */
 void* tcpAcceptorThread(void* p)
 {
   ClientState* cs = (ClientState*) p;
-
+  bool tcpClientCountIncremented = false;
   ComboAddress remote;
   remote.sin4.sin_family = cs->local.sin4.sin_family;
   
@@ -566,6 +611,7 @@ void* tcpAcceptorThread(void* p)
   for(;;) {
     bool queuedCounterIncremented = false;
     ConnectionInfo* ci = nullptr;
+    tcpClientCountIncremented = false;
     try {
       ci = new ConnectionInfo;
       ci->cs = cs;
@@ -589,8 +635,22 @@ void* tcpAcceptorThread(void* p)
         continue;
       }
 
+      if (g_maxTCPConnectionsPerClient) {
+        std::lock_guard<std::mutex> lock(tcpClientsCountMutex);
+
+        if (tcpClientsCount[remote] >= g_maxTCPConnectionsPerClient) {
+          close(ci->fd);
+          delete ci;
+          ci=nullptr;
+          vinfolog("Dropping TCP connection from %s because we have too many from this client already", remote.toStringWithPort());
+          continue;
+        }
+        tcpClientsCount[remote]++;
+        tcpClientCountIncremented = true;
+      }
+
       vinfolog("Got TCP connection from %s", remote.toStringWithPort());
-      
+
       ci->remote = remote;
       int pipe = g_tcpclientthreads->getThread();
       if (pipe >= 0) {
@@ -603,12 +663,18 @@ void* tcpAcceptorThread(void* p)
         close(ci->fd);
         delete ci;
         ci=nullptr;
+        if(tcpClientCountIncremented) {
+          decrementTCPClientCount(remote);
+        }
       }
     }
     catch(std::exception& e) {
       errlog("While reading a TCP question: %s", e.what());
       if(ci && ci->fd >= 0) 
        close(ci->fd);
+      if(tcpClientCountIncremented) {
+        decrementTCPClientCount(remote);
+      }
       delete ci;
       ci = nullptr;
       if (queuedCounterIncremented) {
index cab656521a9da762586c7046b527022887172253..eeb2a6d7d3e45ab049507c4a9e28f7e9fb282b59 100644 (file)
@@ -611,6 +611,9 @@ extern uint16_t g_maxOutstanding;
 extern std::atomic<bool> g_configurationDone;
 extern uint64_t g_maxTCPClientThreads;
 extern uint64_t g_maxTCPQueuedConnections;
+extern size_t g_maxTCPQueriesPerConn;
+extern size_t g_maxTCPConnectionDuration;
+extern size_t g_maxTCPConnectionsPerClient;
 extern std::atomic<uint16_t> g_cacheCleaningDelay;
 extern std::atomic<uint16_t> g_cacheCleaningPercentage;
 extern bool g_verboseHealthChecks;
index 038500276cc0286ddc1f9f37a8ffb84f4f17811c..2eca97d7d5e1280c14bd0086fbdeaff1b302cebf 100644 (file)
@@ -106,9 +106,15 @@ size_t readn2(int fd, void* buffer, size_t len)
   return len;
 }
 
-size_t readn2WithTimeout(int fd, void* buffer, size_t len, int timeout)
+size_t readn2WithTimeout(int fd, void* buffer, size_t len, int idleTimeout, int totalTimeout)
 {
   size_t pos = 0;
+  time_t start = 0;
+  int remainingTime = totalTimeout;
+  if (totalTimeout) {
+    start = time(NULL);
+  }
+
   do {
     ssize_t got = read(fd, (char *)buffer + pos, len - pos);
     if (got > 0) {
@@ -119,7 +125,7 @@ size_t readn2WithTimeout(int fd, void* buffer, size_t len, int timeout)
     }
     else {
       if (errno == EAGAIN) {
-        int res = waitForData(fd, timeout);
+        int res = waitForData(fd, (totalTimeout == 0 || idleTimeout <= remainingTime) ? idleTimeout : remainingTime);
         if (res > 0) {
           /* there is data available */
         }
@@ -133,6 +139,16 @@ size_t readn2WithTimeout(int fd, void* buffer, size_t len, int timeout)
         unixDie("failed in readn2WithTimeout");
       }
     }
+
+    if (totalTimeout) {
+      time_t now = time(NULL);
+      int elapsed = now - start;
+      if (elapsed >= remainingTime) {
+        throw runtime_error("Timeout while reading data");
+      }
+      start = now;
+      remainingTime -= elapsed;
+    }
   }
   while (pos < len);
 
index fa3104d55dd4e60f7c5ae329c31bd5f757162434..50caa8bf80c99fef08542c6ad788588bebf5bd7f 100644 (file)
@@ -147,7 +147,7 @@ vstringtok (Container &container, string const &in,
 size_t writen2(int fd, const void *buf, size_t count);
 inline size_t writen2(int fd, const std::string &s) { return writen2(fd, s.data(), s.size()); }
 size_t readn2(int fd, void* buffer, size_t len);
-size_t readn2WithTimeout(int fd, void* buffer, size_t len, int timeout);
+size_t readn2WithTimeout(int fd, void* buffer, size_t len, int idleTimeout, int totalTimeout=0);
 size_t writen2WithTimeout(int fd, const void * buffer, size_t len, int timeout);
 
 const string toLower(const string &upper);
index d04c247803db5428ec6a88a8a93112beb65ea2f3..f62aed9baaa33ffcb0bbd17a42a6f2048ff29270 100644 (file)
@@ -235,42 +235,57 @@ class DNSDistTest(unittest.TestCase):
         return (receivedQuery, message)
 
     @classmethod
-    def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
-        if useQueue:
-            cls._toResponderQueue.put(response, True, timeout)
+    def openTCPConnection(cls, timeout=None):
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         if timeout:
             sock.settimeout(timeout)
 
         sock.connect(("127.0.0.1", cls._dnsDistPort))
+        return sock
 
-        try:
-            if not rawQuery:
-                wire = query.to_wire()
-            else:
-                wire = query
+    @classmethod
+    def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False):
+        if not rawQuery:
+            wire = query.to_wire()
+        else:
+            wire = query
 
-            sock.send(struct.pack("!H", len(wire)))
-            sock.send(wire)
-            data = sock.recv(2)
+        sock.send(struct.pack("!H", len(wire)))
+        sock.send(wire)
+
+    @classmethod
+    def recvTCPResponseOverConnection(cls, sock):
+        message = None
+        data = sock.recv(2)
+        if data:
+            (datalen,) = struct.unpack("!H", data)
+            data = sock.recv(datalen)
             if data:
-                (datalen,) = struct.unpack("!H", data)
-                data = sock.recv(datalen)
+                message = dns.message.from_wire(data)
+        return message
+
+    @classmethod
+    def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
+        message = None
+        if useQueue:
+            cls._toResponderQueue.put(response, True, timeout)
+
+        sock = cls.openTCPConnection(timeout)
+
+        try:
+            cls.sendTCPQueryOverConnection(sock, query, rawQuery)
+            message = cls.recvTCPResponseOverConnection(sock)
         except socket.timeout as e:
             print("Timeout: %s" % (str(e)))
-            data = None
         except socket.error as e:
             print("Network error: %s" % (str(e)))
-            data = None
         finally:
             sock.close()
 
         receivedQuery = None
-        message = None
         if useQueue and not cls._fromResponderQueue.empty():
             receivedQuery = cls._fromResponderQueue.get(True, timeout)
-        if data:
-            message = dns.message.from_wire(data)
+
         return (receivedQuery, message)
 
     @classmethod
index 6f54dc30218ec6cc5368312a04793350bbc55b07..9fed01195846d8c731838757608bfdc49016649c 100644 (file)
@@ -24,10 +24,6 @@ class TestAXFR(DNSDistTest):
         cls._TCPResponder.setDaemon(True)
         cls._TCPResponder.start()
 
-    _config_template = """
-    newServer{address="127.0.0.1:%s"}
-    """
-
     def testOneMessageAXFR(self):
         """
         AXFR: One message
diff --git a/regression-tests.dnsdist/test_TCPLimits.py b/regression-tests.dnsdist/test_TCPLimits.py
new file mode 100644 (file)
index 0000000..fb9dc03
--- /dev/null
@@ -0,0 +1,114 @@
+#!/usr/bin/env python
+import struct
+import time
+import dns
+from dnsdisttests import DNSDistTest
+
+class TestTCPLimits(DNSDistTest):
+
+    _tcpIdleTimeout = 2
+    _maxTCPQueriesPerConn = 5
+    _maxTCPConnsPerClient = 3
+    _maxTCPConnDuration = 5
+    _config_template = """
+    newServer{address="127.0.0.1:%s"}
+    setTCPRecvTimeout(%s)
+    setMaxTCPQueriesPerConnection(%s)
+    setMaxTCPConnectionsPerClient(%s)
+    setMaxTCPConnectionDuration(%s)
+    """
+    _config_params = ['_testServerPort', '_tcpIdleTimeout', '_maxTCPQueriesPerConn', '_maxTCPConnsPerClient', '_maxTCPConnDuration']
+
+    def testTCPQueriesPerConn(self):
+        """
+        TCP Limits: Maximum number of queries
+        """
+        name = 'maxqueriesperconn.tcp.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        conn = self.openTCPConnection()
+
+        count = 0
+        for idx in xrange(self._maxTCPQueriesPerConn):
+            try:
+                self.sendTCPQueryOverConnection(conn, query)
+                response = self.recvTCPResponseOverConnection(conn)
+                self.assertTrue(response)
+                count = count + 1
+            except:
+                pass
+
+        # this one should fail
+        failed = False
+        try:
+            self.sendTCPQueryOverConnection(conn, query)
+            response = self.recvTCPResponseOverConnection(conn)
+            self.assertFalse(response)
+            if not response:
+                failed = True
+            else:
+                count = count + 1
+        except:
+            failed = True
+
+        conn.close()
+        self.assertTrue(failed)
+        self.assertEqual(count, self._maxTCPQueriesPerConn)
+
+    def testTCPConnsPerClient(self):
+        """
+        TCP Limits: Maximum number of conns per client
+        """
+        name = 'maxconnsperclient.tcp.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        conns = []
+
+        for idx in xrange(self._maxTCPConnsPerClient + 1):
+            conns.append(self.openTCPConnection())
+
+        count = 0
+        failed = 0
+        for conn in conns:
+            try:
+                self.sendTCPQueryOverConnection(conn, query)
+                response = self.recvTCPResponseOverConnection(conn)
+                if response:
+                    count = count + 1
+                else:
+                    failed = failed + 1
+            except:
+                failed = failed + 1
+
+        for conn in conns:
+            conn.close()
+
+        self.assertEqual(count, self._maxTCPConnsPerClient)
+        self.assertEqual(failed, 1)
+
+    def testTCPDuration(self):
+        """
+        TCP Limits: Maximum duration
+        """
+        name = 'duration.tcp.tests.powerdns.com.'
+
+        start = time.time()
+        conn = self.openTCPConnection()
+        # immediately send the maximum size
+        conn.send(struct.pack("!H", 65535))
+
+        count = 0
+        while count < (self._maxTCPConnDuration * 2):
+            try:
+                # sleeping for only one second keeps us below the
+                # idle timeout (setTCPRecvTimeout())
+                time.sleep(1)
+                conn.send('A')
+                count = count + 1
+            except:
+                break
+
+        end = time.time()
+
+        self.assertAlmostEquals(count, self._maxTCPConnDuration, delta=2)
+        self.assertAlmostEquals(end - start, self._maxTCPConnDuration, delta=2)
+
+        conn.close()