]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - pdns/mplexer.hh
mplexer: Keep TTD ordered so we can scan for timeouts efficiently
[thirdparty/pdns.git] / pdns / mplexer.hh
index b42e90092847f8dcb35bc80032ca939b68a17652..a008ec7cf7983b917f77650a4047e1f3eccb89e0 100644 (file)
 #include <boost/shared_array.hpp>
 #include <boost/tuple/tuple.hpp>
 #include <boost/tuple/tuple_comparison.hpp>
+#include <boost/multi_index_container.hpp>
+#include <boost/multi_index/ordered_index.hpp>
+#include <boost/multi_index/hashed_index.hpp>
+#include <boost/multi_index/key_extractors.hpp>
 #include <vector>
 #include <map>
 #include <stdexcept>
 #include <string>
 #include <sys/time.h>
 
+using namespace ::boost::multi_index;
+
 class FDMultiplexerException : public std::runtime_error
 {
 public:
@@ -57,8 +63,9 @@ protected:
   struct Callback
   {
     callbackfunc_t d_callback;
-    funcparam_t d_parameter;
+    mutable funcparam_t d_parameter;
     struct timeval d_ttd;
+    int d_fd;
   };
 
 public:
@@ -109,8 +116,10 @@ public:
       throw FDMultiplexerException("attempt to timestamp fd not in the multiplexer");
     }
 
+    auto newEntry = *it;
     tv.tv_sec += timeout;
-    it->second.d_ttd = tv;
+    newEntry.d_ttd = tv;
+    d_readCallbacks.replace(it, newEntry);
   }
 
   virtual void setWriteTTD(int fd, struct timeval tv, int timeout)
@@ -120,29 +129,23 @@ public:
       throw FDMultiplexerException("attempt to timestamp fd not in the multiplexer");
     }
 
+    auto newEntry = *it;
     tv.tv_sec += timeout;
-    it->second.d_ttd = tv;
-  }
-
-  virtual funcparam_t& getReadParameter(int fd) 
-  {
-    const auto& it = d_readCallbacks.find(fd);
-    if(it == d_readCallbacks.end()) {
-      throw FDMultiplexerException("attempt to look up data in multiplexer for unlisted fd "+std::to_string(fd));
-    }
-
-    return it->second.d_parameter;
+    newEntry.d_ttd = tv;
+    d_writeCallbacks.replace(it, newEntry);
   }
 
   virtual std::vector<std::pair<int, funcparam_t> > getTimeouts(const struct timeval& tv, bool writes=false)
   {
-    const auto tied = boost::tie(tv.tv_sec, tv.tv_usec);
     std::vector<std::pair<int, funcparam_t> > ret;
+    const auto tied = boost::tie(tv.tv_sec, tv.tv_usec);
+    auto& idx = writes ? d_writeCallbacks.get<TTDOrderedTag>() : d_readCallbacks.get<TTDOrderedTag>();
 
-    for(const auto& entry : (writes ? d_writeCallbacks : d_readCallbacks)) {
-      if(entry.second.d_ttd.tv_sec && tied > boost::tie(entry.second.d_ttd.tv_sec, entry.second.d_ttd.tv_usec)) {
-        ret.push_back(std::make_pair(entry.first, entry.second.d_parameter));
+    for (auto it = idx.begin(); it != idx.end(); ++it) {
+      if (it->d_ttd.tv_sec == 0 || tied <= boost::tie(it->d_ttd.tv_sec, it->d_ttd.tv_usec)) {
+        break;
       }
+      ret.push_back(std::make_pair(it->d_fd, it->d_parameter));
     }
 
     return ret;
@@ -160,7 +163,42 @@ public:
   virtual std::string getName() const = 0;
 
 protected:
-  typedef std::map<int, Callback> callbackmap_t;
+  struct FDBasedTag {};
+  struct TTDOrderedTag {};
+  struct ttd_compare
+  {
+    /* we want a 0 TTD (no timeout) to come _after_ everything else */
+    bool operator() (const struct timeval& lhs, const struct timeval& rhs) const
+    {
+      /* special treatment if at least one of the TTD is 0,
+         normal comparison otherwise */
+      if (lhs.tv_sec == 0 && rhs.tv_sec == 0) {
+        return false;
+      }
+      if (lhs.tv_sec == 0 && rhs.tv_sec != 0) {
+        return false;
+      }
+      if (lhs.tv_sec != 0 && rhs.tv_sec == 0) {
+        return true;
+      }
+
+      return std::tie(lhs.tv_sec, lhs.tv_usec) < std::tie(rhs.tv_sec, rhs.tv_usec);
+    }
+  };
+
+  typedef multi_index_container<
+    Callback,
+    indexed_by <
+                hashed_unique<tag<FDBasedTag>,
+                              member<Callback,int,&Callback::d_fd>
+                              >,
+                ordered_non_unique<tag<TTDOrderedTag>,
+                                   member<Callback,struct timeval,&Callback::d_ttd>,
+                                   ttd_compare
+                                   >
+               >
+  > callbackmap_t;
+
   callbackmap_t d_readCallbacks, d_writeCallbacks;
 
   virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd=nullptr)=0;
@@ -171,6 +209,7 @@ protected:
   void accountingAddFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd=nullptr)
   {
     Callback cb;
+    cb.d_fd = fd;
     cb.d_callback=toDo;
     cb.d_parameter=parameter;
     memset(&cb.d_ttd, 0, sizeof(cb.d_ttd));
@@ -178,7 +217,7 @@ protected:
       cb.d_ttd = *ttd;
     }
 
-    auto pair = cbmap.insert({fd, cb});
+    auto pair = cbmap.insert(cb);
     if (!pair.second) {
       throw FDMultiplexerException("Tried to add fd "+std::to_string(fd)+ " to multiplexer twice");
     }