]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Handle waiting for a descriptor to become readable OR writable
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 4 Aug 2021 12:35:53 +0000 (14:35 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 4 Aug 2021 12:35:53 +0000 (14:35 +0200)
This commit refactors our multiplexers to be able to wait for a
descriptor to become readable OR writable at the same time.
I kept the two separate maps for an easier handling of the separate
TTD and to limit the amount of changes, but we might want to merge
them into a single map in the future.
The accounting is moved into the parent class instead of being dealt
with by the multiplexers themselves.

I noticed that the poll multiplexer allocates and fills a vector of
pollfd for every call to run(), which seems wasteful, but I did not
want to touch that in this commit.

I did not compile or test the kqueue, ports and /dev/poll multiplexers
yet, so don't merge this without testing them first.

.not-formatted
pdns/devpollmplexer.cc
pdns/dnsdistdist/test-dnsdisttcp_cc.cc
pdns/epollmplexer.cc
pdns/kqueuemplexer.cc
pdns/mplexer.hh
pdns/pollmplexer.cc
pdns/portsmplexer.cc
pdns/test-mplexer.cc

index 3116793fff331ad0e78a11f98bd67efa9596f690..6a2ad40a943a1b8a8fb6bd075a6b4bb36ebcca5b 100644 (file)
@@ -39,7 +39,6 @@
 ./pdns/decafsigners.cc
 ./pdns/delaypipe.cc
 ./pdns/delaypipe.hh
-./pdns/devpollmplexer.cc
 ./pdns/digests.hh
 ./pdns/distributor.hh
 ./pdns/dns.cc
 ./pdns/ednspadding.cc
 ./pdns/ednssubnet.cc
 ./pdns/ednssubnet.hh
-./pdns/epollmplexer.cc
 ./pdns/filterpo.cc
 ./pdns/filterpo.hh
 ./pdns/fstrm_logger.cc
 ./pdns/ixplore.cc
 ./pdns/json.cc
 ./pdns/json.hh
-./pdns/kqueuemplexer.cc
 ./pdns/kvresp.cc
 ./pdns/lazy_allocator.hh
 ./pdns/libssl.cc
 ./pdns/minicurl.hh
 ./pdns/misc.cc
 ./pdns/misc.hh
-./pdns/mplexer.hh
 ./pdns/mtasker.cc
 ./pdns/mtasker.hh
 ./pdns/mtasker_context.hh
 ./pdns/pdnsutil.cc
 ./pdns/pkcs11signers.cc
 ./pdns/pkcs11signers.hh
-./pdns/pollmplexer.cc
-./pdns/portsmplexer.cc
 ./pdns/protozero.cc
 ./pdns/protozero.hh
 ./pdns/proxy-protocol.cc
 ./pdns/test-lock_hh.cc
 ./pdns/test-lua_auth4_cc.cc
 ./pdns/test-misc_hh.cc
-./pdns/test-mplexer.cc
 ./pdns/test-nameserver_cc.cc
 ./pdns/test-packetcache_cc.cc
 ./pdns/test-packetcache_hh.cc
index ff31b01819ca4bfe0cb0a366252aac0ba9d1497c..715206d775e3ab90ac49752152c4ca3d1ae23d58 100644 (file)
@@ -40,25 +40,28 @@ class DevPollFDMultiplexer : public FDMultiplexer
 {
 public:
   DevPollFDMultiplexer();
-  virtual ~DevPollFDMultiplexer()
+  ~DevPollFDMultiplexer()
   {
-    close(d_devpollfd);
+    if (d_devpollfd >= 0) {
+      close(d_devpollfd);
+    }
   }
 
-  virtual int run(struct timeval* tv, int timeout=500) override;
-  virtual void getAvailableFDs(std::vector<int>& fds, int timeout) override;
+  int run(struct timeval* tv, int timeout = 500) override;
+  void getAvailableFDs(std::vector<int>& fds, int timeout) override;
+
+  void addFD(int fd, FDMultiplexer::EventKind kind) override;
+  void removeFD(int fd, FDMultiplexer::EventKind kind) override;
 
-  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd=nullptr) override;
-  virtual void removeFD(callbackmap_t& cbmap, int fd) override;
   string getName() const override
   {
     return "/dev/poll";
   }
+
 private:
   int d_devpollfd;
 };
 
-
 static FDMultiplexer* makeDevPoll()
 {
   return new DevPollFDMultiplexer();
@@ -66,49 +69,53 @@ static FDMultiplexer* makeDevPoll()
 
 static struct DevPollRegisterOurselves
 {
-  DevPollRegisterOurselves() {
+  DevPollRegisterOurselves()
+  {
     FDMultiplexer::getMultiplexerMap().insert(make_pair(0, &makeDevPoll)); // priority 0!
   }
 } doItDevPoll;
 
-
-//int DevPollFDMultiplexer::s_maxevents=1024;
-DevPollFDMultiplexer::DevPollFDMultiplexer() 
+DevPollFDMultiplexer::DevPollFDMultiplexer()
 {
-  d_devpollfd=open("/dev/poll", O_RDWR);
-  if(d_devpollfd < 0)
-    throw FDMultiplexerException("Setting up /dev/poll: "+stringerror());
-    
+  d_devpollfd = open("/dev/poll", O_RDWR);
+  if (d_devpollfd < 0) {
+    throw FDMultiplexerException("Setting up /dev/poll: " + stringerror());
+  }
 }
 
-void DevPollFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd)
+static int convertEventKind(FDMultiplexer::EventKind kind)
 {
-  accountingAddFD(cbmap, fd, toDo, parameter, ttd);
+  switch (kind) {
+  case FDMultiplexer::EventKind::Read:
+    return POLLIN;
+  case FDMultiplexer::EventKind::Write:
+    return POLLOUT;
+  case FDMultiplexer::EventKind::Both:
+    return POLLIN | POLLOUT;
+  }
+}
 
+void DevPollFDMultiplexer::addFD(int fd, FDMultiplexer::EventKind kind)
+{
   struct pollfd devent;
-  devent.fd=fd;
-  devent.events= (&cbmap == &d_readCallbacks) ? POLLIN : POLLOUT;
+  devent.fd = fd;
+  devent.events = convertEventKind(kind);
   devent.revents = 0;
 
-  if(write(d_devpollfd, &devent, sizeof(devent)) != sizeof(devent)) {
-    cbmap.erase(fd);
-    throw FDMultiplexerException("Adding fd to /dev/poll/ set: "+stringerror());
+  if (write(d_devpollfd, &devent, sizeof(devent)) != sizeof(devent)) {
+    throw FDMultiplexerException("Adding fd to /dev/poll/ set: " + stringerror());
   }
 }
 
-void DevPollFDMultiplexer::removeFD(callbackmap_t& cbmap, int fd)
+void DevPollFDMultiplexer::removeFD(int fd, FDMultiplexer::EventKind)
 {
-  if(!cbmap.erase(fd))
-    throw FDMultiplexerException("Tried to remove unlisted fd "+std::to_string(fd)+ " from multiplexer");
-
   struct pollfd devent;
-  devent.fd=fd;
-  devent.events= POLLREMOVE;
+  devent.fd = fd;
+  devent.events = POLLREMOVE;
   devent.revents = 0;
 
-  if(write(d_devpollfd, &devent, sizeof(devent)) != sizeof(devent)) {
-    cbmap.erase(fd);
-    throw FDMultiplexerException("Removing fd from epoll set: "+stringerror());
+  if (write(d_devpollfd, &devent, sizeof(devent)) != sizeof(devent)) {
+    throw FDMultiplexerException("Removing fd from epoll set: " + stringerror());
   }
 }
 
@@ -119,20 +126,20 @@ void DevPollFDMultiplexer::getAvailableFDs(std::vector<int>& fds, int timeout)
   dvp.dp_nfds = d_readCallbacks.size() + d_writeCallbacks.size();
   dvp.dp_fds = pollfds.data();
   dvp.dp_timeout = timeout;
-  int ret=ioctl(d_devpollfd, DP_POLL, &dvp);
+  int ret = ioctl(d_devpollfd, DP_POLL, &dvp);
 
-  if(ret < 0 && errno!=EINTR) {
-    throw FDMultiplexerException("/dev/poll returned error: "+stringerror());
+  if (ret < 0 && errno != EINTR) {
+    throw FDMultiplexerException("/dev/poll returned error: " + stringerror());
   }
 
-  for(int n=0; n < ret; ++n) {
+  for (int n = 0; n < ret; ++n) {
     fds.push_back(pollfds.at(n).fd);
   }
 }
 
 int DevPollFDMultiplexer::run(struct timeval* now, int timeout)
 {
-  if(d_inrun) {
+  if (d_inrun) {
     throw FDMultiplexerException("FDMultiplexer::run() is not reentrant!\n");
   }
   std::vector<struct pollfd> fds(d_readCallbacks.size() + d_writeCallbacks.size());
@@ -140,34 +147,36 @@ int DevPollFDMultiplexer::run(struct timeval* now, int timeout)
   dvp.dp_nfds = d_readCallbacks.size() + d_writeCallbacks.size();
   dvp.dp_fds = fds.data();
   dvp.dp_timeout = timeout;
-  int ret=ioctl(d_devpollfd, DP_POLL, &dvp);
+  int ret = ioctl(d_devpollfd, DP_POLL, &dvp);
   int err = errno;
-  gettimeofday(now,0); // MANDATORY!
+  gettimeofday(now, nullptr); // MANDATORY!
 
-  if(ret < 0 && err!=EINTR) {
-    throw FDMultiplexerException("/dev/poll returned error: "+stringerror(err));
+  if (ret < 0 && err != EINTR) {
+    throw FDMultiplexerException("/dev/poll returned error: " + stringerror(err));
   }
 
-  if(ret < 1) { // thanks AB!
+  if (ret < 1) { // thanks AB!
     return 0;
   }
 
-  d_inrun=true;
-  for(int n=0; n < ret; ++n) {
-    d_iter=d_readCallbacks.find(fds.at(n).fd);
-    
-    if(d_iter != d_readCallbacks.end()) {
-      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
-      continue; // so we don't refind ourselves as writable!
+  d_inrun = true;
+  for (int n = 0; n < ret; ++n) {
+    if ((fds.at(n).revents & POLLIN) || (fds.at(n).revents & POLLERR) || (fds.at(n).revents & POLLHUP)) {
+      const auto& iter = d_readCallbacks.find(fds.at(n).fd);
+      if (iter != d_readCallbacks.end()) {
+        iter->d_callback(iter->d_fd, iter->d_parameter);
+      }
     }
-    d_iter=d_writeCallbacks.find(fds.at(n).fd);
-    
-    if(d_iter != d_writeCallbacks.end()) {
-      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
+
+    if ((fds.at(n).revents & POLLOUT) || (fds.at(n).revents & POLLERR)) {
+      const auto& iter = d_writeCallbacks.find(fds.at(n).fd);
+      if (iter != d_writeCallbacks.end()) {
+        iter->d_callback(iter->d_fd, iter->d_parameter);
+      }
     }
   }
 
-  d_inrun=false;
+  d_inrun = false;
   return ret;
 }
 
@@ -186,7 +195,7 @@ void acceptData(int fd, funcparam_t& parameter)
 int main()
 {
   Socket s(AF_INET, SOCK_DGRAM);
-  
+
   IPEndpoint loc("0.0.0.0", 2000);
   s.bind(loc);
 
@@ -201,5 +210,3 @@ int main()
   sfm.removeReadFD(s.getHandle());
 }
 #endif
-
-
index 7007c5c88cb519e35c6758ce4cbb6dd5ef8af200..edcec82b84335e56e17feda31a2481f726450f17 100644 (file)
@@ -341,20 +341,16 @@ public:
   {
   }
 
-  void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd=nullptr) override
+  void addFD(int fd, FDMultiplexer::EventKind kind) override
   {
-    accountingAddFD(cbmap, fd, toDo, parameter, ttd);
   }
 
-  void removeFD(callbackmap_t& cbmap, int fd) override
+  void removeFD(int fd, FDMultiplexer::EventKind) override
   {
-    accountingRemoveFD(cbmap, fd);
   }
 
-  void alterFD(callbackmap_t& from, callbackmap_t& to, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd) override
+  void alterFD(int fd, FDMultiplexer::EventKind kind) override
   {
-    accountingRemoveFD(from, fd);
-    accountingAddFD(to, fd, toDo, parameter, ttd);
   }
 
   string getName() const override
index 7fa63e6237ba21b43fa693ca9cd2fad43f483a30..e80f1fc06632f28de9c1268b95361e489748f9b9 100644 (file)
@@ -37,29 +37,31 @@ class EpollFDMultiplexer : public FDMultiplexer
 {
 public:
   EpollFDMultiplexer();
-  virtual ~EpollFDMultiplexer()
+  ~EpollFDMultiplexer()
   {
-    close(d_epollfd);
+    if (d_epollfd >= 0) {
+      close(d_epollfd);
+    }
   }
 
-  virtual int run(struct timeval* tv, int timeout=500) override;
-  virtual void getAvailableFDs(std::vector<int>& fds, int timeout) override;
+  int run(struct timeval* tv, int timeout = 500) override;
+  void getAvailableFDs(std::vector<int>& fds, int timeout) override;
 
-  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd=nullptr) override;
-  virtual void removeFD(callbackmap_t& cbmap, int fd) override;
-  virtual void alterFD(callbackmap_t& from, callbackmap_t& to, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd) override;
+  void addFD(int fd, FDMultiplexer::EventKind kind) override;
+  void removeFD(int fd, FDMultiplexer::EventKind kind) override;
+  void alterFD(int fd, FDMultiplexer::EventKind kind) override;
 
   string getName() const override
   {
     return "epoll";
   }
+
 private:
   int d_epollfd;
   boost::shared_array<epoll_event> d_eevents;
   static int s_maxevents; // not a hard maximum
 };
 
-
 static FDMultiplexer* makeEpoll()
 {
   return new EpollFDMultiplexer();
@@ -67,122 +69,137 @@ static FDMultiplexer* makeEpoll()
 
 static struct EpollRegisterOurselves
 {
-  EpollRegisterOurselves() {
+  EpollRegisterOurselves()
+  {
     FDMultiplexer::getMultiplexerMap().insert(make_pair(0, &makeEpoll)); // priority 0!
   }
 } doItEpoll;
 
-int EpollFDMultiplexer::s_maxevents=1024;
+int EpollFDMultiplexer::s_maxevents = 1024;
 
-EpollFDMultiplexer::EpollFDMultiplexer() : d_eevents(new epoll_event[s_maxevents])
+EpollFDMultiplexer::EpollFDMultiplexer() :
+  d_eevents(new epoll_event[s_maxevents])
 {
-  d_epollfd=epoll_create(s_maxevents); // not hard max
-  if(d_epollfd < 0)
-    throw FDMultiplexerException("Setting up epoll: "+stringerror());
-  int fd=socket(AF_INET, SOCK_DGRAM, 0); // for self-test
-  if(fd < 0)
+  d_epollfd = epoll_create(s_maxevents); // not hard max
+  if (d_epollfd < 0) {
+    throw FDMultiplexerException("Setting up epoll: " + stringerror());
+  }
+  int fd = socket(AF_INET, SOCK_DGRAM, 0); // for self-test
+
+  if (fd < 0) {
     return;
+  }
+
   try {
     addReadFD(fd, 0);
     removeReadFD(fd);
     close(fd);
     return;
   }
-  catch(FDMultiplexerException &fe) {
+  catch (const FDMultiplexerException& fe) {
     close(fd);
     close(d_epollfd);
-    throw FDMultiplexerException("epoll multiplexer failed self-test: "+string(fe.what()));
+    throw FDMultiplexerException("epoll multiplexer failed self-test: " + string(fe.what()));
   }
-
 }
 
-void EpollFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd)
+static uint32_t convertEventKind(FDMultiplexer::EventKind kind)
 {
-  accountingAddFD(cbmap, fd, toDo, parameter, ttd);
+  switch (kind) {
+  case FDMultiplexer::EventKind::Read:
+    return EPOLLIN;
+  case FDMultiplexer::EventKind::Write:
+    return EPOLLOUT;
+  case FDMultiplexer::EventKind::Both:
+    return EPOLLIN | EPOLLOUT;
+  }
+}
 
+void EpollFDMultiplexer::addFD(int fd, FDMultiplexer::EventKind kind)
+{
   struct epoll_event eevent;
 
-  eevent.events = (&cbmap == &d_readCallbacks) ? EPOLLIN : EPOLLOUT;
+  eevent.events = convertEventKind(kind);
 
-  eevent.data.u64=0; // placate valgrind (I love it so much)
-  eevent.data.fd=fd;
+  eevent.data.u64 = 0; // placate valgrind (I love it so much)
+  eevent.data.fd = fd;
 
   if (epoll_ctl(d_epollfd, EPOLL_CTL_ADD, fd, &eevent) < 0) {
-    cbmap.erase(fd);
-    throw FDMultiplexerException("Adding fd to epoll set: "+stringerror());
+    throw FDMultiplexerException("Adding fd to epoll set: " + stringerror());
   }
 }
 
-void EpollFDMultiplexer::removeFD(callbackmap_t& cbmap, int fd)
+void EpollFDMultiplexer::removeFD(int fd, FDMultiplexer::EventKind)
 {
-  accountingRemoveFD(cbmap, fd);
-
   struct epoll_event dummy;
   dummy.events = 0;
   dummy.data.u64 = 0;
 
-  if(epoll_ctl(d_epollfd, EPOLL_CTL_DEL, fd, &dummy) < 0)
-    throw FDMultiplexerException("Removing fd from epoll set: "+stringerror());
+  if (epoll_ctl(d_epollfd, EPOLL_CTL_DEL, fd, &dummy) < 0) {
+    throw FDMultiplexerException("Removing fd from epoll set: " + stringerror());
+  }
 }
 
-void EpollFDMultiplexer::alterFD(callbackmap_t& from, callbackmap_t& to, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd)
+void EpollFDMultiplexer::alterFD(int fd, FDMultiplexer::EventKind kind)
 {
-  accountingRemoveFD(from, fd);
-  accountingAddFD(to, fd, toDo, parameter, ttd);
-
   struct epoll_event eevent;
-  eevent.events = (&to == &d_readCallbacks) ? EPOLLIN : EPOLLOUT;
+  eevent.events = convertEventKind(kind);
   eevent.data.u64 = 0; // placate valgrind (I love it so much)
   eevent.data.fd = fd;
 
   if (epoll_ctl(d_epollfd, EPOLL_CTL_MOD, fd, &eevent) < 0) {
-    to.erase(fd);
-    throw FDMultiplexerException("Altering fd in epoll set: "+stringerror());
+    throw FDMultiplexerException("Altering fd in epoll set: " + stringerror());
   }
 }
 
 void EpollFDMultiplexer::getAvailableFDs(std::vector<int>& fds, int timeout)
 {
-  int ret=epoll_wait(d_epollfd, d_eevents.get(), s_maxevents, timeout);
+  int ret = epoll_wait(d_epollfd, d_eevents.get(), s_maxevents, timeout);
 
-  if(ret < 0 && errno!=EINTR)
-    throw FDMultiplexerException("epoll returned error: "+stringerror());
+  if (ret < 0 && errno != EINTR) {
+    throw FDMultiplexerException("epoll returned error: " + stringerror());
+  }
 
-  for(int n=0; n < ret; ++n) {
+  for (int n = 0; n < ret; ++n) {
     fds.push_back(d_eevents[n].data.fd);
   }
 }
 
 int EpollFDMultiplexer::run(struct timeval* now, int timeout)
 {
-  if(d_inrun) {
+  if (d_inrun) {
     throw FDMultiplexerException("FDMultiplexer::run() is not reentrant!\n");
   }
 
-  int ret=epoll_wait(d_epollfd, d_eevents.get(), s_maxevents, timeout);
-  gettimeofday(now,0); // MANDATORY
+  int ret = epoll_wait(d_epollfd, d_eevents.get(), s_maxevents, timeout);
+  gettimeofday(now, nullptr); // MANDATORY
 
-  if(ret < 0 && errno!=EINTR)
-    throw FDMultiplexerException("epoll returned error: "+stringerror());
+  if (ret < 0 && errno != EINTR) {
+    throw FDMultiplexerException("epoll returned error: " + stringerror());
+  }
 
-  if(ret < 1) // thanks AB!
+  if (ret < 1) { // thanks AB!
     return 0;
+  }
 
-  d_inrun=true;
-  for(int n=0; n < ret; ++n) {
-    d_iter=d_readCallbacks.find(d_eevents[n].data.fd);
-
-    if(d_iter != d_readCallbacks.end()) {
-      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
-      continue; // so we don't refind ourselves as writable!
+  d_inrun = true;
+  for (int n = 0; n < ret; ++n) {
+    if ((d_eevents[n].events & EPOLLIN) || (d_eevents[n].events & EPOLLERR) || (d_eevents[n].events & EPOLLHUP)) {
+      const auto& iter = d_readCallbacks.find(d_eevents[n].data.fd);
+      if (iter != d_readCallbacks.end()) {
+        iter->d_callback(iter->d_fd, iter->d_parameter);
+      }
     }
-    d_iter=d_writeCallbacks.find(d_eevents[n].data.fd);
 
-    if(d_iter != d_writeCallbacks.end()) {
-      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
+    if ((d_eevents[n].events & EPOLLOUT) || (d_eevents[n].events & EPOLLERR) || (d_eevents[n].events & EPOLLHUP)) {
+      const auto& iter = d_writeCallbacks.find(d_eevents[n].data.fd);
+      if (iter != d_writeCallbacks.end()) {
+        iter->d_callback(iter->d_fd, iter->d_parameter);
+      }
     }
   }
-  d_inrun=false;
+
+  d_inrun = false;
   return ret;
 }
 
index cb8a3efb4e7a4e5c7acc700eb90f922f7ec29444..72511589760d079cf2e4cb6a02b49485358d9056 100644 (file)
@@ -39,27 +39,31 @@ class KqueueFDMultiplexer : public FDMultiplexer
 {
 public:
   KqueueFDMultiplexer();
-  virtual ~KqueueFDMultiplexer()
+  ~KqueueFDMultiplexer()
   {
-    close(d_kqueuefd);
+    if (d_kqueuefd >= 0) {
+      close(d_kqueuefd);
+    }
   }
 
-  virtual int run(struct timeval* tv, int timeout=500) override;
-  virtual void getAvailableFDs(std::vector<int>& fds, int timeout) override;
+  int run(struct timeval* tv, int timeout = 500) override;
+  void getAvailableFDs(std::vector<int>& fds, int timeout) override;
+
+  void addFD(int fd, FDMultiplexer::EventKind kind) override;
+  void removeFD(int fd) override;
 
-  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter, const struct timeval* ttd=nullptr) override;
-  virtual void removeFD(callbackmap_t& cbmap, int fd) override;
   string getName() const override
   {
     return "kqueue";
   }
+
 private:
   int d_kqueuefd;
   boost::shared_array<struct kevent> d_kevents;
   static unsigned int s_maxevents; // not a hard maximum
 };
 
-unsigned int KqueueFDMultiplexer::s_maxevents=1024;
+unsigned int KqueueFDMultiplexer::s_maxevents = 1024;
 
 static FDMultiplexer* make()
 {
@@ -68,94 +72,112 @@ static FDMultiplexer* make()
 
 static struct KqueueRegisterOurselves
 {
-  KqueueRegisterOurselves() {
+  KqueueRegisterOurselves()
+  {
     FDMultiplexer::getMultiplexerMap().insert(make_pair(0, &make)); // priority 0!
   }
 } kQueueDoIt;
 
-KqueueFDMultiplexer::KqueueFDMultiplexer() : d_kevents(new struct kevent[s_maxevents])
+KqueueFDMultiplexer::KqueueFDMultiplexer() :
+  d_kevents(new struct kevent[s_maxevents])
 {
-  d_kqueuefd=kqueue();
-  if(d_kqueuefd < 0)
-    throw FDMultiplexerException("Setting up kqueue: "+stringerror());
+  d_kqueuefd = kqueue();
+  if (d_kqueuefd < 0) {
+    throw FDMultiplexerException("Setting up kqueue: " + stringerror());
+  }
 }
 
-void KqueueFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter, const struct timeval* ttd)
+static uint32_t convertEventKind(FDMultiplexer::EventKind kind)
 {
-  accountingAddFD(cbmap, fd, toDo, parameter, ttd);
+  switch (kind) {
+  case FDMultiplexer::EventKind::Read:
+    return EVFILT_READ;
+  case FDMultiplexer::EventKind::Write:
+    return EVFILT_WRITE;
+  case FDMultiplexer::EventKind::Both:
+    return EVFILT_READ | EVFILT_WRITE;
+  }
+}
 
+void KqueueFDMultiplexer::addFD(int fd, FDMultiplexer::EventKind kind)
+{
   struct kevent kqevent;
-  EV_SET(&kqevent, fd, (&cbmap == &d_readCallbacks) ? EVFILT_READ : EVFILT_WRITE, EV_ADD, 0,0,0);
+  EV_SET(&kqevent, fd, convertEventKind(kind), EV_ADD, 0, 0, 0);
 
-  if(kevent(d_kqueuefd, &kqevent, 1, 0, 0, 0) < 0) {
-    cbmap.erase(fd);
-    throw FDMultiplexerException("Adding fd to kqueue set: "+stringerror());
+  if (kevent(d_kqueuefd, &kqevent, 1, 0, 0, 0) < 0) {
+    throw FDMultiplexerException("Adding fd to kqueue set: " + stringerror());
   }
 }
 
-void KqueueFDMultiplexer::removeFD(callbackmap_t& cbmap, int fd)
+void KqueueFDMultiplexer::removeFD(int fd, FDMultiplexer::EventKind kind)
 {
-  accountingRemoveFD(cbmap, fd);
-
   struct kevent kqevent;
-  EV_SET(&kqevent, fd, (&cbmap == &d_readCallbacks) ? EVFILT_READ : EVFILT_WRITE, EV_DELETE, 0,0,0);
-  
-  if(kevent(d_kqueuefd, &kqevent, 1, 0, 0, 0) < 0) // ponder putting Callback back on the map..
-    throw FDMultiplexerException("Removing fd from kqueue set: "+stringerror());
+  EV_SET(&kqevent, fd, convertEventKind(kind), EV_DELETE, 0, 0, 0);
+
+  if (kevent(d_kqueuefd, &kqevent, 1, 0, 0, 0) < 0) {
+    // ponder putting Callback back on the map..
+    throw FDMultiplexerException("Removing fd from kqueue set: " + stringerror());
+  }
 }
 
 void KqueueFDMultiplexer::getAvailableFDs(std::vector<int>& fds, int timeout)
 {
   struct timespec ts;
-  ts.tv_sec=timeout/1000;
-  ts.tv_nsec=(timeout % 1000) * 1000000;
+  ts.tv_sec = timeout / 1000;
+  ts.tv_nsec = (timeout % 1000) * 1000000;
 
   int ret = kevent(d_kqueuefd, 0, 0, d_kevents.get(), s_maxevents, &ts);
 
-  if(ret < 0 && errno != EINTR)
-    throw FDMultiplexerException("kqueue returned error: "+stringerror());
+  if (ret < 0 && errno != EINTR) {
+    throw FDMultiplexerException("kqueue returned error: " + stringerror());
+  }
 
-  for(int n=0; n < ret; ++n) {
+  for (int n = 0; n < ret; ++n) {
     fds.push_back(d_kevents[n].ident);
   }
 }
 
 int KqueueFDMultiplexer::run(struct timeval* now, int timeout)
 {
-  if(d_inrun) {
+  if (d_inrun) {
     throw FDMultiplexerException("FDMultiplexer::run() is not reentrant!\n");
   }
-  
+
   struct timespec ts;
-  ts.tv_sec=timeout/1000;
-  ts.tv_nsec=(timeout % 1000) * 1000000;
+  ts.tv_sec = timeout / 1000;
+  ts.tv_nsec = (timeout % 1000) * 1000000;
 
-  int ret=kevent(d_kqueuefd, 0, 0, d_kevents.get(), s_maxevents, &ts);
-  gettimeofday(now,0); // MANDATORY!
+  int ret = kevent(d_kqueuefd, 0, 0, d_kevents.get(), s_maxevents, &ts);
+  gettimeofday(now, nullptr); // MANDATORY!
 
-  if(ret < 0 && errno!=EINTR)
-    throw FDMultiplexerException("kqueue returned error: "+stringerror());
+  if (ret < 0 && errno != EINTR) {
+    throw FDMultiplexerException("kqueue returned error: " + stringerror());
+  }
 
-  if(ret < 0) // nothing - thanks AB!
+  if (ret < 0) {
+    // nothing - thanks AB!
     return 0;
+  }
 
-  d_inrun=true;
+  d_inrun = true;
 
-  for(int n=0; n < ret; ++n) {
-    d_iter=d_readCallbacks.find(d_kevents[n].ident);
-    if(d_iter != d_readCallbacks.end()) {
-      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
-      continue; // so we don't find ourselves as writable again
+  for (int n = 0; n < ret; ++n) {
+    if (d_kevents[n].filter == EVFILT_READ) {
+      const auto& iter = d_readCallbacks.find(d_kevents[n].ident);
+      if (iter != d_readCallbacks.end()) {
+        iter->d_callback(iter->d_fd, iter->d_parameter);
+      }
     }
 
-    d_iter=d_writeCallbacks.find(d_kevents[n].ident);
-
-    if(d_iter != d_writeCallbacks.end()) {
-      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
+    if (d_kevents[n].filter == EVFILT_WRITE) {
+      const auto& iter = d_writeCallbacks.find(d_kevents[n].ident);
+      if (iter != d_writeCallbacks.end()) {
+        iter->d_callback(iter->d_fd, iter->d_parameter);
+      }
     }
   }
 
-  d_inrun=false;
+  d_inrun = false;
   return ret;
 }
 
@@ -173,7 +195,7 @@ void acceptData(int fd, boost::any& parameter)
 int main()
 {
   Socket s(AF_INET, SOCK_DGRAM);
-  
+
   IPEndpoint loc("0.0.0.0", 2000);
   s.bind(loc);
 
@@ -188,6 +210,3 @@ int main()
   sfm.removeReadFD(s.getHandle());
 }
 #endif
-
-
-
index 9408a2ad3748a9ac7555de327c95a5351e7a9319..d1c573a1efcc812a0a901875e81fa12ba58da2dc 100644 (file)
@@ -40,11 +40,11 @@ using namespace ::boost::multi_index;
 class FDMultiplexerException : public std::runtime_error
 {
 public:
-  FDMultiplexerException(const std::string& str) : std::runtime_error(str)
+  FDMultiplexerException(const std::string& str) :
+    std::runtime_error(str)
   {}
 };
 
-
 /** Very simple FD multiplexer, based on callbacks and boost::any parameters
     As a special service, this parameter is kept around and can be modified, 
     allowing for state to be stored inside the multiplexer.
@@ -56,9 +56,15 @@ class FDMultiplexer
 {
 public:
   typedef boost::any funcparam_t;
-  typedef boost::function< void(int, funcparam_t&) > callbackfunc_t;
-protected:
+  typedef boost::function<void(int, funcparam_t&)> callbackfunc_t;
+  enum class EventKind : uint8_t
+  {
+    Read,
+    Write,
+    Both
+  };
 
+protected:
   struct Callback
   {
     callbackfunc_t d_callback;
@@ -68,49 +74,86 @@ protected:
   };
 
 public:
-  FDMultiplexer() : d_inrun(false)
+  FDMultiplexer() :
+    d_inrun(false)
   {}
   virtual ~FDMultiplexer()
   {}
 
   static FDMultiplexer* getMultiplexerSilent();
-  
+
   /* tv will be updated to 'now' before run returns */
   /* timeout is in ms */
   /* returns 0 on timeout, -1 in case of error (but all implementations
      actually throw in that case) and the number of ready events otherwise */
-  virtual int run(struct timeval* tv, int timeout=500) = 0;
+  virtual int run(struct timeval* tv, int timeout = 500) = 0;
 
   /* timeout is in ms, 0 will return immediately, -1 will block until at least one FD is ready */
   virtual void getAvailableFDs(std::vector<int>& fds, int timeout) = 0;
 
   //! Add an fd to the read watch list - currently an fd can only be on one list at a time!
-  virtual void addReadFD(int fd, callbackfunc_t toDo, const funcparam_t& parameter=funcparam_t(), const struct timeval* ttd=nullptr)
+  void addReadFD(int fd, callbackfunc_t toDo, const funcparam_t& parameter = funcparam_t(), const struct timeval* ttd = nullptr)
   {
-    this->addFD(d_readCallbacks, fd, toDo, parameter, ttd);
+    bool alreadyWatched = d_writeCallbacks.count(fd) > 0;
+
+    if (alreadyWatched) {
+      this->alterFD(fd, EventKind::Both);
+    }
+    else {
+      this->addFD(fd, EventKind::Read);
+    }
+
+    /* do the addition _after_ so the entry is not added if there is an error */
+    accountingAddFD(d_readCallbacks, fd, toDo, parameter, ttd);
   }
 
   //! Add an fd to the write watch list - currently an fd can only be on one list at a time!
-  virtual void addWriteFD(int fd, callbackfunc_t toDo, const funcparam_t& parameter=funcparam_t(), const struct timeval* ttd=nullptr)
+  void addWriteFD(int fd, callbackfunc_t toDo, const funcparam_t& parameter = funcparam_t(), const struct timeval* ttd = nullptr)
   {
-    this->addFD(d_writeCallbacks, fd, toDo, parameter, ttd);
+    bool alreadyWatched = d_readCallbacks.count(fd) > 0;
+
+    if (alreadyWatched) {
+      this->alterFD(fd, EventKind::Both);
+    }
+    else {
+      this->addFD(fd, EventKind::Write);
+    }
+
+    /* do the addition _after_ so the entry is not added if there is an error */
+    accountingAddFD(d_writeCallbacks, fd, toDo, parameter, ttd);
   }
 
   //! Remove an fd from the read watch list. You can't call this function on an fd that is closed already!
   /** WARNING: references to 'parameter' become invalid after this function! */
-  virtual void removeReadFD(int fd)
+  void removeReadFD(int fd)
   {
-    this->removeFD(d_readCallbacks, fd);
+    const auto& iter = d_writeCallbacks.find(fd);
+    accountingRemoveFD(d_readCallbacks, fd);
+
+    if (iter != d_writeCallbacks.end()) {
+      this->alterFD(fd, EventKind::Write);
+    }
+    else {
+      this->removeFD(fd, EventKind::Read);
+    }
   }
 
   //! Remove an fd from the write watch list. You can't call this function on an fd that is closed already!
   /** WARNING: references to 'parameter' become invalid after this function! */
-  virtual void removeWriteFD(int fd)
+  void removeWriteFD(int fd)
   {
-    this->removeFD(d_writeCallbacks, fd);
+    const auto& iter = d_readCallbacks.find(fd);
+    accountingRemoveFD(d_writeCallbacks, fd);
+
+    if (iter != d_readCallbacks.end()) {
+      this->alterFD(fd, EventKind::Read);
+    }
+    else {
+      this->removeFD(fd, EventKind::Write);
+    }
   }
 
-  virtual void setReadTTD(int fd, struct timeval tv, int timeout)
+  void setReadTTD(int fd, struct timeval tv, int timeout)
   {
     const auto& it = d_readCallbacks.find(fd);
     if (it == d_readCallbacks.end()) {
@@ -123,7 +166,7 @@ public:
     d_readCallbacks.replace(it, newEntry);
   }
 
-  virtual void setWriteTTD(int fd, struct timeval tv, int timeout)
+  void setWriteTTD(int fd, struct timeval tv, int timeout)
   {
     const auto& it = d_writeCallbacks.find(fd);
     if (it == d_writeCallbacks.end()) {
@@ -136,19 +179,23 @@ public:
     d_writeCallbacks.replace(it, newEntry);
   }
 
-  virtual void alterFDToRead(int fd, callbackfunc_t toDo, const funcparam_t& parameter=funcparam_t(), const struct timeval* ttd=nullptr)
+  void alterFDToRead(int fd, callbackfunc_t toDo, const funcparam_t& parameter = funcparam_t(), const struct timeval* ttd = nullptr)
   {
-    this->alterFD(d_writeCallbacks, d_readCallbacks, fd, toDo, parameter, ttd);
+    accountingRemoveFD(d_writeCallbacks, fd);
+    this->alterFD(fd, EventKind::Read);
+    accountingAddFD(d_readCallbacks, fd, toDo, parameter, ttd);
   }
 
-  virtual void alterFDToWrite(int fd, callbackfunc_t toDo, const funcparam_t& parameter=funcparam_t(), const struct timeval* ttd=nullptr)
+  void alterFDToWrite(int fd, callbackfunc_t toDo, const funcparam_t& parameter = funcparam_t(), const struct timeval* ttd = nullptr)
   {
-    this->alterFD(d_readCallbacks, d_writeCallbacks, fd, toDo, parameter, ttd);
+    accountingRemoveFD(d_readCallbacks, fd);
+    this->alterFD(fd, EventKind::Write);
+    accountingAddFD(d_writeCallbacks, fd, toDo, parameter, ttd);
   }
 
-  virtual std::vector<std::pair<int, funcparam_t> > getTimeouts(const struct timeval& tv, bool writes=false)
+  std::vector<std::pair<int, funcparam_t>> getTimeouts(const struct timeval& tv, bool writes = false)
   {
-    std::vector<std::pair<int, funcparam_t> > ret;
+    std::vector<std::pair<int, funcparam_t>> ret;
     const auto tied = boost::tie(tv.tv_sec, tv.tv_usec);
     auto& idx = writes ? d_writeCallbacks.get<TTDOrderedTag>() : d_readCallbacks.get<TTDOrderedTag>();
 
@@ -170,7 +217,7 @@ public:
     static FDMultiplexermap_t theMap;
     return theMap;
   }
-  
+
   virtual std::string getName() const = 0;
 
   size_t getWatchedFDCount(bool writeFDs) const
@@ -178,7 +225,7 @@ public:
     return writeFDs ? d_writeCallbacks.size() : d_readCallbacks.size();
   }
 
-  void runForAllWatchedFDs(void(*watcher)(bool isRead, int fd, const funcparam_t&, struct timeval))
+  void runForAllWatchedFDs(void (*watcher)(bool isRead, int fd, const funcparam_t&, struct timeval))
   {
     for (const auto& entry : d_readCallbacks) {
       watcher(true, entry.d_fd, entry.d_parameter, entry.d_ttd);
@@ -189,12 +236,16 @@ public:
   }
 
 protected:
-  struct FDBasedTag {};
-  struct TTDOrderedTag {};
+  struct FDBasedTag
+  {
+  };
+  struct TTDOrderedTag
+  {
+  };
   struct ttd_compare
   {
     /* we want a 0 TTD (no timeout) to come _after_ everything else */
-    bool operator() (const struct timeval& lhs, const struct timeval& rhs) const
+    bool operator()(const struct timeval& lhs, const struct timeval& rhs) const
     {
       /* special treatment if at least one of the TTD is 0,
          normal comparison otherwise */
@@ -214,31 +265,23 @@ protected:
 
   typedef multi_index_container<
     Callback,
-    indexed_by <
-                hashed_unique<tag<FDBasedTag>,
-                              member<Callback,int,&Callback::d_fd>
-                              >,
-                ordered_non_unique<tag<TTDOrderedTag>,
-                                   member<Callback,struct timeval,&Callback::d_ttd>,
-                                   ttd_compare
-                                   >
-               >
-  > callbackmap_t;
+    indexed_by<
+      hashed_unique<tag<FDBasedTag>,
+                    member<Callback, int, &Callback::d_fd>>,
+      ordered_non_unique<tag<TTDOrderedTag>,
+                         member<Callback, struct timeval, &Callback::d_ttd>,
+                         ttd_compare>>>
+    callbackmap_t;
 
   callbackmap_t d_readCallbacks, d_writeCallbacks;
-
-  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd=nullptr)=0;
-  virtual void removeFD(callbackmap_t& cbmap, int fd)=0;
-
   bool d_inrun;
-  callbackmap_t::iterator d_iter;
 
-  void accountingAddFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd=nullptr)
+  void accountingAddFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd)
   {
     Callback cb;
     cb.d_fd = fd;
-    cb.d_callback=toDo;
-    cb.d_parameter=parameter;
+    cb.d_callback = toDo;
+    cb.d_parameter = parameter;
     memset(&cb.d_ttd, 0, sizeof(cb.d_ttd));
     if (ttd) {
       cb.d_ttd = *ttd;
@@ -246,22 +289,24 @@ protected:
 
     auto pair = cbmap.insert(std::move(cb));
     if (!pair.second) {
-      throw FDMultiplexerException("Tried to add fd "+std::to_string(fd)+ " to multiplexer twice");
+      throw FDMultiplexerException("Tried to add fd " + std::to_string(fd) + " to multiplexer twice");
     }
   }
 
-  void accountingRemoveFD(callbackmap_t& cbmap, int fd) 
+  void accountingRemoveFD(callbackmap_t& cbmap, int fd)
   {
-    if(!cbmap.erase(fd)) {
-      throw FDMultiplexerException("Tried to remove unlisted fd "+std::to_string(fd)+ " from multiplexer");
+    if (!cbmap.erase(fd)) {
+      throw FDMultiplexerException("Tried to remove unlisted fd " + std::to_string(fd) + " from multiplexer");
     }
   }
 
-  virtual void alterFD(callbackmap_t& from, callbackmap_t& to, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd)
+  virtual void addFD(int fd, EventKind kind) = 0;
+  /* most implementations do not care about which event has to be removed, except for kqueue */
+  virtual void removeFD(int fd, EventKind kind) = 0;
+  virtual void alterFD(int fd, EventKind kind)
   {
     /* naive implementation */
-    removeFD(from, fd);
-    addFD(to, fd, toDo, parameter, ttd);
+    removeFD(fd, (kind == EventKind::Write) ? EventKind::Read : EventKind::Write);
+    addFD(fd, kind);
   }
-
 };
index 665b4b823c8a9788a391fc8bc77ab58c978f5761..bd01f0b36d0d0a6b98964b6e3325b1ca5a7bde7a 100644 (file)
 FDMultiplexer* FDMultiplexer::getMultiplexerSilent()
 {
   FDMultiplexer* ret = nullptr;
-  for(const auto& i : FDMultiplexer::getMultiplexerMap()) {
+  for (const auto& i : FDMultiplexer::getMultiplexerMap()) {
     try {
       ret = i.second();
       return ret;
     }
-    catch(const FDMultiplexerException& fe) {
+    catch (const FDMultiplexerException& fe) {
     }
-    catch(...) {
+    catch (...) {
     }
   }
   return ret;
 }
 
-
 class PollFDMultiplexer : public FDMultiplexer
 {
 public:
   PollFDMultiplexer()
   {}
-  virtual ~PollFDMultiplexer()
+  ~PollFDMultiplexer()
   {
   }
 
-  virtual int run(struct timeval* tv, int timeout=500) override;
-  virtual void getAvailableFDs(std::vector<int>& fds, int timeout) override;
+  int run(struct timeval* tv, int timeout = 500) override;
+  void getAvailableFDs(std::vector<int>& fds, int timeout) override;
 
-  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd=nullptr) override;
-  virtual void removeFD(callbackmap_t& cbmap, int fd) override;
+  void addFD(int fd, FDMultiplexer::EventKind) override;
+  void removeFD(int fd, FDMultiplexer::EventKind) override;
 
   string getName() const override
   {
     return "poll";
   }
+
 private:
   vector<struct pollfd> preparePollFD() const;
 };
@@ -55,55 +55,69 @@ static FDMultiplexer* make()
 
 static struct RegisterOurselves
 {
-  RegisterOurselves() {
+  RegisterOurselves()
+  {
     FDMultiplexer::getMultiplexerMap().insert(make_pair(1, &make));
   }
 } doIt;
 
-void PollFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter, const struct timeval* ttd)
+void PollFDMultiplexer::addFD(int fd, FDMultiplexer::EventKind kind)
 {
-  accountingAddFD(cbmap, fd, toDo, parameter, ttd);
 }
 
-void PollFDMultiplexer::removeFD(callbackmap_t& cbmap, int fd)
+void PollFDMultiplexer::removeFD(int fd, FDMultiplexer::EventKind)
 {
-  if(d_inrun && d_iter->d_fd==fd)  // trying to remove us!
-    ++d_iter;
-
-  accountingRemoveFD(cbmap, fd);
 }
 
 vector<struct pollfd> PollFDMultiplexer::preparePollFD() const
 {
-  vector<struct pollfd> pollfds;
+  std::unordered_map<int, struct pollfd> pollfds;
   pollfds.reserve(d_readCallbacks.size() + d_writeCallbacks.size());
 
-  struct pollfd pollfd;
-  for(const auto& cb : d_readCallbacks) {
-    pollfd.fd = cb.d_fd;
-    pollfd.events = POLLIN;
-    pollfds.push_back(pollfd);
+  for (const auto& cb : d_readCallbacks) {
+    if (pollfds.count(cb.d_fd) == 0) {
+      auto& pollfd = pollfds[cb.d_fd];
+      pollfd.fd = cb.d_fd;
+      pollfd.events = 0;
+    }
+    auto& pollfd = pollfds.at(cb.d_fd);
+    pollfd.events |= POLLIN;
+  }
+
+  for (const auto& cb : d_writeCallbacks) {
+    if (pollfds.count(cb.d_fd) == 0) {
+      auto& pollfd = pollfds[cb.d_fd];
+      pollfd.fd = cb.d_fd;
+      pollfd.events = 0;
+    }
+    auto& pollfd = pollfds.at(cb.d_fd);
+    pollfd.events |= POLLOUT;
   }
 
-  for(const auto& cb : d_writeCallbacks) {
-    pollfd.fd = cb.d_fd;
-    pollfd.events = POLLOUT;
-    pollfds.push_back(pollfd);
+  std::vector<struct pollfd> result;
+  result.reserve(pollfds.size());
+  for (const auto& entry : pollfds) {
+    result.push_back(entry.second);
   }
 
-  return pollfds;
+  return result;
 }
 
 void PollFDMultiplexer::getAvailableFDs(std::vector<int>& fds, int timeout)
 {
   auto pollfds = preparePollFD();
+  if (pollfds.empty()) {
+    return;
+  }
+
   int ret = poll(&pollfds[0], pollfds.size(), timeout);
 
-  if (ret < 0 && errno != EINTR)
+  if (ret < 0 && errno != EINTR) {
     throw FDMultiplexerException("poll returned error: " + stringerror());
+  }
 
-  for(const auto& pollfd : pollfds) {
-    if (pollfd.revents & POLLIN || pollfd.revents & POLLOUT) {
+  for (const auto& pollfd : pollfds) {
+    if (pollfd.revents & POLLIN || pollfd.revents & POLLOUT || pollfd.revents & POLLERR || pollfd.revents & POLLHUP) {
       fds.push_back(pollfd.fd);
     }
   }
@@ -111,39 +125,43 @@ void PollFDMultiplexer::getAvailableFDs(std::vector<int>& fds, int timeout)
 
 int PollFDMultiplexer::run(struct timeval* now, int timeout)
 {
-  if(d_inrun) {
+  if (d_inrun) {
     throw FDMultiplexerException("FDMultiplexer::run() is not reentrant!\n");
   }
 
   auto pollfds = preparePollFD();
+  if (pollfds.empty()) {
+    gettimeofday(now, nullptr); // MANDATORY!
+    return 0;
+  }
 
-  int ret=poll(&pollfds[0], pollfds.size(), timeout);
-  gettimeofday(now, 0); // MANDATORY!
-  
-  if(ret < 0 && errno!=EINTR)
-    throw FDMultiplexerException("poll returned error: "+stringerror());
-
-  d_iter=d_readCallbacks.end();
-  d_inrun=true;
-
-  for(const auto& pollfd : pollfds) {
-    if(pollfd.revents & POLLIN) {
-      d_iter=d_readCallbacks.find(pollfd.fd);
-    
-      if(d_iter != d_readCallbacks.end()) {
-        d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
-        continue; // so we don't refind ourselves as writable!
+  int ret = poll(&pollfds[0], pollfds.size(), timeout);
+  gettimeofday(now, nullptr); // MANDATORY!
+
+  if (ret < 0 && errno != EINTR) {
+    throw FDMultiplexerException("poll returned error: " + stringerror());
+  }
+
+  d_inrun = true;
+
+  for (const auto& pollfd : pollfds) {
+
+    if (pollfd.revents & POLLIN || pollfd.revents & POLLERR || pollfd.revents & POLLHUP) {
+      const auto& iter = d_readCallbacks.find(pollfd.fd);
+      if (iter != d_readCallbacks.end()) {
+        iter->d_callback(iter->d_fd, iter->d_parameter);
       }
     }
-    else if(pollfd.revents & POLLOUT) {
-      d_iter=d_writeCallbacks.find(pollfd.fd);
-    
-      if(d_iter != d_writeCallbacks.end()) {
-        d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
+
+    if (pollfd.revents & POLLOUT || pollfd.revents & POLLERR) {
+      const auto& iter = d_writeCallbacks.find(pollfd.fd);
+      if (iter != d_writeCallbacks.end()) {
+        iter->d_callback(iter->d_fd, iter->d_parameter);
       }
     }
   }
-  d_inrun=false;
+
+  d_inrun = false;
   return ret;
 }
 
@@ -163,7 +181,7 @@ void acceptData(int fd, boost::any& parameter)
 int main()
 {
   Socket s(AF_INET, SOCK_DGRAM);
-  
+
   IPEndpoint loc("0.0.0.0", 2000);
   s.bind(loc);
 
@@ -178,4 +196,3 @@ int main()
   sfm.removeReadFD(s.getHandle());
 }
 #endif
-
index 39a0aa07ec931992f912f39979327df2e7366619..3ee0a37524c709127139b6b2c22aa0468bd400fe 100644 (file)
@@ -18,27 +18,29 @@ class PortsFDMultiplexer : public FDMultiplexer
 {
 public:
   PortsFDMultiplexer();
-  virtual ~PortsFDMultiplexer()
+  ~PortsFDMultiplexer()
   {
     close(d_portfd);
   }
 
-  virtual int run(struct timeval* tv, int timeout=500) override;
-  virtual void getAvailableFDs(std::vector<int>& fds, int timeout) override;
+  int run(struct timeval* tv, int timeout = 500) override;
+  void getAvailableFDs(std::vector<int>& fds, int timeout) override;
+
+  void addFD(int fd, FDMultiplexer::EventKind kind) override;
+  void removeFD(int fd, FDMultiplexer::EventKind kind) override;
+  void alterFD(int fd, FDMultiplexer::EventKind kind) override;
 
-  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter, const struct timeval* ttd=nullptr) override;
-  virtual void removeFD(callbackmap_t& cbmap, int fd) override;
   string getName() const override
   {
     return "solaris completion ports";
   }
+
 private:
   int d_portfd;
   boost::shared_array<port_event_t> d_pevents;
   static int s_maxevents; // not a hard maximum
 };
 
-
 static FDMultiplexer* makePorts()
 {
   return new PortsFDMultiplexer();
@@ -46,37 +48,47 @@ static FDMultiplexer* makePorts()
 
 static struct PortsRegisterOurselves
 {
-  PortsRegisterOurselves() {
+  PortsRegisterOurselves()
+  {
     FDMultiplexer::getMultiplexerMap().insert(make_pair(0, &makePorts)); // priority 0!
   }
 } doItPorts;
 
+int PortsFDMultiplexer::s_maxevents = 1024;
 
-int PortsFDMultiplexer::s_maxevents=1024;
-PortsFDMultiplexer::PortsFDMultiplexer() : d_pevents(new port_event_t[s_maxevents])
+PortsFDMultiplexer::PortsFDMultiplexer() :
+  d_pevents(new port_event_t[s_maxevents])
 {
-  d_portfd=port_create(); // not hard max
-  if(d_portfd < 0)
-    throw FDMultiplexerException("Setting up port: "+stringerror());
+  d_portfd = port_create(); // not hard max
+  if (d_portfd < 0) {
+    throw FDMultiplexerException("Setting up port: " + stringerror());
+  }
 }
 
-void PortsFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter, const struct timeval* ttd)
+static int convertEventKind(FDMultiplexer::EventKind kind)
 {
-  accountingAddFD(cbmap, fd, toDo, parameter, ttd);
-
-  if(port_associate(d_portfd, PORT_SOURCE_FD, fd, (&cbmap == &d_readCallbacks) ? POLLIN : POLLOUT, 0) < 0) {
-    cbmap.erase(fd);
-    throw FDMultiplexerException("Adding fd to port set: "+stringerror());
+  switch (kind) {
+  case FDMultiplexer::EventKind::Read:
+    return POLLIN;
+  case FDMultiplexer::EventKind::Write:
+    return POLLOUT;
+  case FDMultiplexer::EventKind::Both:
+    return POLLIN | POLLOUT;
   }
 }
 
-void PortsFDMultiplexer::removeFD(callbackmap_t& cbmap, int fd)
+void PortsFDMultiplexer::addFD(int fd, FDMultiplexer::EventKind kind)
 {
-  if(!cbmap.erase(fd))
-    throw FDMultiplexerException("Tried to remove unlisted fd "+std::to_string(fd)+ " from multiplexer");
+  if (port_associate(d_portfd, PORT_SOURCE_FD, fd, convertEventKind(kind), 0) < 0) {
+    throw FDMultiplexerException("Adding fd to port set: " + stringerror());
+  }
+}
 
-  if(port_dissociate(d_portfd, PORT_SOURCE_FD, fd) < 0 && errno != ENOENT) // it appears under some circumstances, ENOENT will be returned, without this being an error. Apache has this same "fix"
-    throw FDMultiplexerException("Removing fd from port set: "+stringerror());
+void PortsFDMultiplexer::removeFD(int fd, FDMultiplexer::EventKind)
+{
+  if (port_dissociate(d_portfd, PORT_SOURCE_FD, fd) < 0 && errno != ENOENT) { // it appears under some circumstances, ENOENT will be returned, without this being an error. Apache has this same "fix"
+    throw FDMultiplexerException("Removing fd from port set: " + stringerror());
+  }
 }
 
 void PortsFDMultiplexer::getAvailableFDs(std::vector<int>& fds, int timeout)
@@ -113,16 +125,21 @@ void PortsFDMultiplexer::getAvailableFDs(std::vector<int>& fds, int timeout)
     const auto fd = d_pevents[n].portev_object;
 
     /* we need to re-associate the FD */
-    if (d_readCallbacks.count(fd)) {
-      if (port_associate(d_portfd, PORT_SOURCE_FD, fd, POLLIN, 0) < 0) {
-        throw FDMultiplexerException("Unable to add fd back to ports (read): " + stringerror());
+    if ((d_pevents[n].portev_events & POLLIN || d_pevents[n].portev_events & POLLER || d_pevents[n].portev_events & POLLHUP)) {
+      if (d_readCallbacks.count(fd)) {
+        if (port_associate(d_portfd, PORT_SOURCE_FD, fd, d_writeCallbacks.count(fd) > 0 ? POLLIN | POLLOUT : POLLIN, 0) < 0) {
+          throw FDMultiplexerException("Unable to add fd back to ports (read): " + stringerror());
+        }
       }
     }
-    else if (d_writeCallbacks.count(fd)) {
-      if (port_associate(d_portfd, PORT_SOURCE_FD, fd, POLLOUT, 0) < 0) {
-        throw FDMultiplexerException("Unable to add fd back to ports (write): " + stringerror());
+    else if ((d_pevents[n].portev_events & POLLOUT || d_pevents[n].portev_events & POLLER)) {
+      if (d_writeCallbacks.count(fd)) {
+        if (port_associate(d_portfd, PORT_SOURCE_FD, fd, d_readCallbacks.count(fd) > 0 ? POLLIN | POLLOUT : POLLOUT, 0) < 0) {
+          throw FDMultiplexerException("Unable to add fd back to ports (write): " + stringerror());
+        }
       }
-    } else {
+    }
+    else {
       /* not registered, this is unexpected */
       continue;
     }
@@ -133,58 +150,60 @@ void PortsFDMultiplexer::getAvailableFDs(std::vector<int>& fds, int timeout)
 
 int PortsFDMultiplexer::run(struct timeval* now, int timeout)
 {
-  if(d_inrun) {
+  if (d_inrun) {
     throw FDMultiplexerException("FDMultiplexer::run() is not reentrant!\n");
   }
-  
+
   struct timespec timeoutspec;
   timeoutspec.tv_sec = timeout / 1000;
   timeoutspec.tv_nsec = (timeout % 1000) * 1000000;
-  unsigned int numevents=1;
-  int ret= port_getn(d_portfd, d_pevents.get(), min(PORT_MAX_LIST, s_maxevents), &numevents, &timeoutspec);
-  
+  unsigned int numevents = 1;
+  int ret = port_getn(d_portfd, d_pevents.get(), min(PORT_MAX_LIST, s_maxevents), &numevents, &timeoutspec);
+
   /* port_getn has an unusual API - (ret == -1, errno == ETIME) can
      mean partial success; you must check (*numevents) in this case
      and process anything in there, otherwise you'll never see any
      events from that object again. We don't care about pure timeouts
      (ret == -1, errno == ETIME, *numevents == 0) so we don't bother
      with that case. */
-  if(ret == -1 && errno!=ETIME) {
-    if(errno!=EINTR)
-      throw FDMultiplexerException("completion port_getn returned error: "+stringerror());
+  if (ret == -1 && errno != ETIME) {
+    if (errno != EINTR) {
+      throw FDMultiplexerException("completion port_getn returned error: " + stringerror());
+    }
     // EINTR is not really an error
-    gettimeofday(now,0);
+    gettimeofday(now, nullptr);
     return 0;
   }
-  gettimeofday(now,0);
-  if(!numevents) // nothing
+  gettimeofday(now, nullptr);
+  if (!numevents) {
+    // nothing
     return 0;
+  }
 
-  d_inrun=true;
-
-  for(unsigned int n=0; n < numevents; ++n) {
-    d_iter=d_readCallbacks.find(d_pevents[n].portev_object);
-    
-    if(d_iter != d_readCallbacks.end()) {
-      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
-      if(d_readCallbacks.count(d_pevents[n].portev_object) && port_associate(d_portfd, PORT_SOURCE_FD, d_pevents[n].portev_object, 
-                        POLLIN, 0) < 0)
-        throw FDMultiplexerException("Unable to add fd back to ports (read): "+stringerror());
-      continue; // so we don't find ourselves as writable again
-    }
+  d_inrun = true;
 
-    d_iter=d_writeCallbacks.find(d_pevents[n].portev_object);
-    
-    if(d_iter != d_writeCallbacks.end()) {
-      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
-      if(d_writeCallbacks.count(d_pevents[n].portev_object) && port_associate(d_portfd, PORT_SOURCE_FD, d_pevents[n].portev_object, 
-                        POLLOUT, 0) < 0)
-        throw FDMultiplexerException("Unable to add fd back to ports (write): "+stringerror());
+  for (unsigned int n = 0; n < numevents; ++n) {
+    if (d_pevents[n].portev_events & POLLIN || d_pevents[n].portev_events & POLLER || d_pevents[n].portev_events & POLLHUP) {
+      const auto& iter = d_readCallbacks.find(d_pevents[n].portev_object);
+      if (iter != d_readCallbacks.end()) {
+        iter->d_callback(iter->d_fd, iter->d_parameter);
+        if (d_readCallbacks.count(d_pevents[n].portev_object) && port_associate(d_portfd, PORT_SOURCE_FD, d_pevents[n].portev_object, d_writeCallbacks.count(d_pevents[n].portev_object) ? POLLIN | POLLOUT : POLLIN, 0) < 0) {
+          throw FDMultiplexerException("Unable to add fd back to ports (read): " + stringerror());
+        }
+      }
+    }
+    if (d_pevents[n].portev_events & POLLOUT || d_pevents[n].portev_events & POLLER) {
+      const auto& iter = d_writeCallbacks.find(d_pevents[n].portev_object);
+      if (iter != d_writeCallbacks.end()) {
+        iter->d_callback(iter->d_fd, iter->d_parameter);
+        if (d_writeCallbacks.count(d_pevents[n].portev_object) && port_associate(d_portfd, PORT_SOURCE_FD, d_pevents[n].portev_object, d_readCallbacks.count(d_pevents[n].portev_object) ? POLLIN | POLLOUT : POLLOUT, 0) < 0) {
+          throw FDMultiplexerException("Unable to add fd back to ports (write): " + stringerror());
+        }
+      }
     }
-
   }
 
-  d_inrun=false;
+  d_inrun = false;
   return numevents;
 }
 
@@ -203,7 +222,7 @@ void acceptData(int fd, boost::any& parameter)
 int main()
 {
   Socket s(AF_INET, SOCK_DGRAM);
-  
+
   IPEndpoint loc("0.0.0.0", 2000);
   s.bind(loc);
 
@@ -218,5 +237,3 @@ int main()
   sfm.removeReadFD(s.getHandle());
 }
 #endif
-
-
index 83f3713151155a620be18c7eccb690483f2b51ee..08be6a3882543b6695b5273378eb2555379be927 100644 (file)
@@ -10,7 +10,8 @@
 
 BOOST_AUTO_TEST_SUITE(mplexer)
 
-BOOST_AUTO_TEST_CASE(test_MPlexer) {
+BOOST_AUTO_TEST_CASE(test_MPlexer)
+{
   auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent());
   BOOST_REQUIRE(mplexer != nullptr);
 
@@ -37,10 +38,10 @@ BOOST_AUTO_TEST_CASE(test_MPlexer) {
 
   bool writeCBCalled = false;
   auto writeCB = [](int fd, FDMultiplexer::funcparam_t param) {
-                        auto calledPtr = boost::any_cast<bool*>(param);
-                        BOOST_REQUIRE(calledPtr != nullptr);
-                        *calledPtr = true;
-                 };
+    auto calledPtr = boost::any_cast<bool*>(param);
+    BOOST_REQUIRE(calledPtr != nullptr);
+    *calledPtr = true;
+  };
   mplexer->addWriteFD(pipes[1],
                       writeCB,
                       &writeCBCalled,
@@ -85,10 +86,10 @@ BOOST_AUTO_TEST_CASE(test_MPlexer) {
 
   bool readCBCalled = false;
   auto readCB = [](int fd, FDMultiplexer::funcparam_t param) {
-                        auto calledPtr = boost::any_cast<bool*>(param);
-                        BOOST_REQUIRE(calledPtr != nullptr);
-                        *calledPtr = true;
-                };
+    auto calledPtr = boost::any_cast<bool*>(param);
+    BOOST_REQUIRE(calledPtr != nullptr);
+    *calledPtr = true;
+  };
   mplexer->addReadFD(pipes[0],
                      readCB,
                      &readCBCalled,
@@ -205,5 +206,81 @@ BOOST_AUTO_TEST_CASE(test_MPlexer) {
   close(pipes[1]);
 }
 
+BOOST_AUTO_TEST_CASE(test_MPlexer_ReadAndWrite)
+{
+  auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent());
+  BOOST_REQUIRE(mplexer != nullptr);
+
+  int sockets[2];
+  int res = socketpair(AF_UNIX, SOCK_STREAM, 0, sockets);
+  BOOST_REQUIRE_EQUAL(res, 0);
+  BOOST_REQUIRE_EQUAL(setNonBlocking(sockets[0]), true);
+  BOOST_REQUIRE_EQUAL(setNonBlocking(sockets[1]), true);
+
+  struct timeval now;
+  std::vector<int> readyFDs;
+  struct timeval ttd = now;
+  ttd.tv_sec += 5;
+
+  bool readCBCalled = false;
+  bool writeCBCalled = false;
+  auto readCB = [](int fd, FDMultiplexer::funcparam_t param) {
+    auto calledPtr = boost::any_cast<bool*>(param);
+    BOOST_REQUIRE(calledPtr != nullptr);
+    *calledPtr = true;
+  };
+  auto writeCB = [](int fd, FDMultiplexer::funcparam_t param) {
+    auto calledPtr = boost::any_cast<bool*>(param);
+    BOOST_REQUIRE(calledPtr != nullptr);
+    *calledPtr = true;
+  };
+  mplexer->addReadFD(sockets[0],
+                     readCB,
+                     &readCBCalled,
+                     &ttd);
+  mplexer->addWriteFD(sockets[0],
+                      writeCB,
+                      &writeCBCalled,
+                      &ttd);
+
+  /* not ready for reading yet, but should be writable */
+  readyFDs.clear();
+  mplexer->getAvailableFDs(readyFDs, 0);
+  BOOST_REQUIRE_EQUAL(readyFDs.size(), 1U);
+  BOOST_CHECK_EQUAL(readyFDs.at(0), sockets[0]);
+
+  /* let's make the socket readable */
+  BOOST_REQUIRE_EQUAL(write(sockets[1], "0", 1), 1);
+
+  readyFDs.clear();
+  mplexer->getAvailableFDs(readyFDs, 0);
+  BOOST_REQUIRE_EQUAL(readyFDs.size(), 1U);
+  BOOST_CHECK_EQUAL(readyFDs.at(0), sockets[0]);
+
+  auto ready = mplexer->run(&now, 100);
+  BOOST_CHECK_EQUAL(ready, 1);
+  BOOST_CHECK_EQUAL(readCBCalled, true);
+  BOOST_CHECK_EQUAL(writeCBCalled, true);
+
+  /* check that the write cb remains when we remove the read one */
+  mplexer->removeReadFD(sockets[0]);
+
+  readCBCalled = false;
+  writeCBCalled = false;
+  readyFDs.clear();
+  mplexer->getAvailableFDs(readyFDs, 0);
+  BOOST_REQUIRE_EQUAL(readyFDs.size(), 1U);
+  BOOST_CHECK_EQUAL(readyFDs.at(0), sockets[0]);
+  ready = mplexer->run(&now, 100);
+  BOOST_CHECK_EQUAL(ready, 1);
+  BOOST_CHECK_EQUAL(readCBCalled, false);
+  BOOST_CHECK_EQUAL(writeCBCalled, true);
+
+  mplexer->removeWriteFD(sockets[0]);
+
+  /* clean up */
+  close(sockets[0]);
+  close(sockets[1]);
+}
 
 BOOST_AUTO_TEST_SUITE_END()