]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Move to a shared PacketID
authorOtto <otto.moerbeek@open-xchange.com>
Tue, 27 Jul 2021 12:29:23 +0000 (14:29 +0200)
committerOtto <otto.moerbeek@open-xchange.com>
Mon, 2 Aug 2021 10:21:38 +0000 (12:21 +0200)
There is an issue with equal_range() in asendto() that I fixed with an extra
test for the moment.

pdns/mtasker.cc
pdns/pdns_recursor.cc
pdns/rec_channel_rec.cc
pdns/syncres.hh

index 96bf9e0125242e44238e44f4200ed50d49bc9685..8776a4e33f3cc75c2719e308e38f637eb84e9541 100644 (file)
@@ -240,10 +240,9 @@ template<class EventKey, class EventVal>int MTasker<EventKey,EventVal>::sendEven
   typename waiters_t::iterator waiter=d_waiters.find(key);
 
   if(waiter == d_waiters.end()) {
-    //    cout<<"Event sent nobody was waiting for!"<<endl;
+    //cerr<<"Event sent nobody was waiting for! " <<key << endl;
     return 0;
   }
-  
   d_waitstatus=Answer;
   if(val)
     d_waitval=*val;
index 541495f043bc5a88942d54b6d9346a76afeaeaf5..549b00bb4b1d27e0510d31ed0ca4c555ec846272 100644 (file)
@@ -438,7 +438,7 @@ LWResult::Result asendtcp(const PacketBuffer& data, shared_ptr<TCPIOHandler>& ha
   TCPIOHandlerStateChange(IOState::Done, state, pident);
 
   PacketBuffer packet;
-  int ret = MT->waitEvent(*pident, &packet, g_networkTimeoutMsec);
+  int ret = MT->waitEvent(pident, &packet, g_networkTimeoutMsec);
   TCPLOG(pident->tcpsock, "asendtcp waitEvent returned " << ret << ' ' << packet.size() << '/' << data.size() << ' ');
   if (ret == 0) {
     TCPLOG(pident->tcpsock, "timeout" << endl);
@@ -505,7 +505,7 @@ LWResult::Result arecvtcp(PacketBuffer& data, const size_t len, shared_ptr<TCPIO
   // Will set pident->lowState
   TCPIOHandlerStateChange(IOState::Done, state, pident);
 
-  int ret = MT->waitEvent(*pident, &data, g_networkTimeoutMsec);
+  int ret = MT->waitEvent(pident, &data, g_networkTimeoutMsec);
   TCPLOG(pident->tcpsock, "arecvtcp " << ret << ' ' << data.size() << ' ' );
   if (ret == 0) {
     TCPLOG(pident->tcpsock, "timeout" << endl);
@@ -529,15 +529,15 @@ LWResult::Result arecvtcp(PacketBuffer& data, const size_t len, shared_ptr<TCPIO
 
 static void handleGenUDPQueryResponse(int fd, FDMultiplexer::funcparam_t& var)
 {
-  PacketID pident = *boost::any_cast<PacketID>(&var);
+  std::shared_ptr<PacketID> pident = boost::any_cast<std::shared_ptr<PacketID>>(var);
   PacketBuffer resp;
   resp.resize(512);
   ComboAddress fromaddr;
   socklen_t addrlen = sizeof(fromaddr);
 
   ssize_t ret = recvfrom(fd, resp.data(), resp.size(), 0, (sockaddr *)&fromaddr, &addrlen);
-  if (fromaddr != pident.remote) {
-    g_log<<Logger::Notice<<"Response received from the wrong remote host ("<<fromaddr.toStringWithPort()<<" instead of "<<pident.remote.toStringWithPort()<<"), discarding"<<endl;
+  if (fromaddr != pident->remote) {
+    g_log<<Logger::Notice<<"Response received from the wrong remote host ("<<fromaddr.toStringWithPort()<<" instead of "<<pident->remote.toStringWithPort()<<"), discarding"<<endl;
 
   }
 
@@ -562,16 +562,16 @@ PacketBuffer GenUDPQueryResponse(const ComboAddress& dest, const string& query)
   s.connect(dest);
   s.send(query);
 
-  PacketID pident;
-  pident.fd=s.getHandle();
-  pident.remote=dest;
-  pident.type=0;
+  std::shared_ptr<PacketID> pident = std::make_shared<PacketID>();
+  pident->fd = s.getHandle();
+  pident->remote = dest;
+  pident->type = 0;
   t_fdm->addReadFD(s.getHandle(), handleGenUDPQueryResponse, pident);
 
   PacketBuffer data;
   int ret=MT->waitEvent(pident, &data, g_networkTimeoutMsec);
 
-  if(!ret || ret==-1) { // timeout
+  if (!ret || ret==-1) { // timeout
     t_fdm->removeReadFD(s.getHandle());
   }
   else if(data.empty()) {// error, EOF or other
@@ -731,24 +731,25 @@ LWResult::Result asendto(const char *data, size_t len, int flags,
                          const ComboAddress& toaddr, uint16_t id, const DNSName& domain, uint16_t qtype, int* fd)
 {
 
-  PacketID pident;
-  pident.domain = domain;
-  pident.remote = toaddr;
-  pident.type = qtype;
+  auto pident = std::make_shared<PacketID>();
+  pident->domain = domain;
+  pident->remote = toaddr;
+  pident->type = qtype;
 
   // see if there is an existing outstanding request we can chain on to, using partial equivalence function looking for the same
   // query (qname and qtype) to the same host, but with a different message ID
   pair<MT_t::waiters_t::iterator, MT_t::waiters_t::iterator> chain=MT->d_waiters.equal_range(pident, PacketIDBirthdayCompare());
 
   for(; chain.first != chain.second; chain.first++) {
-    if(chain.first->key.fd > -1 && !chain.first->key.closed) { // don't chain onto existing chained waiter or a chain already processed
-      /*
-      cerr<<"Orig: "<<pident.domain<<", "<<pident.remote.toString()<<", id="<<id<<endl;
-      cerr<<"Had hit: "<< chain.first->key.domain<<", "<<chain.first->key.remote.toString()<<", id="<<chain.first->key.id
-          <<", count="<<chain.first->key.chain.size()<<", origfd: "<<chain.first->key.fd<<endl;
-      */
-      chain.first->key.chain.insert(id); // we can chain
-      *fd=-1;                            // gets used in waitEvent / sendEvent later on
+    assert(chain.first->key->domain == pident->domain);
+    if (chain.first->key->domain != pident->domain) {
+      // XXX Actually, this should not happen..., but it does
+      continue;
+    }
+    if(chain.first->key->fd > -1 && !chain.first->key->closed) { // don't chain onto existing chained waiter or a chain already processed
+      //cerr << "Insert " << id << ' ' << pident << " into chain for  " << chain.first->key << endl;
+      chain.first->key->chain.insert(id); // we can chain
+      *fd = -1;                            // gets used in waitEvent / sendEvent later on
       return LWResult::Result::Success;
     }
   }
@@ -758,8 +759,8 @@ LWResult::Result asendto(const char *data, size_t len, int flags,
     return ret;
   }
 
-  pident.fd=*fd;
-  pident.id=id;
+  pident->fd=*fd;
+  pident->id=id;
 
   t_fdm->addReadFD(*fd, handleUDPServerResponse, pident);
   ssize_t sent = send(*fd, data, len, 0);
@@ -780,12 +781,12 @@ LWResult::Result arecvfrom(PacketBuffer& packet, int flags, const ComboAddress&
 {
   static const unsigned int nearMissLimit = ::arg().asNum("spoof-nearmiss-max");
 
-  PacketID pident;
-  pident.fd=fd;
-  pident.id=id;
-  pident.domain=domain;
-  pident.type = qtype;
-  pident.remote=fromaddr;
+  auto pident = std::make_shared<PacketID>();
+  pident->fd = fd;
+  pident->id = id;
+  pident->domain = domain;
+  pident->type = qtype;
+  pident->remote = fromaddr;
 
   int ret=MT->waitEvent(pident, &packet, g_networkTimeoutMsec, now);
 
@@ -798,10 +799,10 @@ LWResult::Result arecvfrom(PacketBuffer& packet, int flags, const ComboAddress&
 
     *d_len=packet.size();
 
-    if (nearMissLimit > 0 && pident.nearMisses > nearMissLimit) {
+    if (nearMissLimit > 0 && pident->nearMisses > nearMissLimit) {
       /* we have received more than nearMissLimit answers on the right IP and port, from the right source (we are using connected sockets),
          for the correct qname and qtype, but with an unexpected message ID. That looks like a spoofing attempt. */
-      g_log<<Logger::Error<<"Too many ("<<pident.nearMisses<<" > "<<nearMissLimit<<") answers with a wrong message ID for '"<<domain<<"' from "<<fromaddr.toString()<<", assuming spoof attempt."<<endl;
+      g_log<<Logger::Error<<"Too many ("<<pident->nearMisses<<" > "<<nearMissLimit<<") answers with a wrong message ID for '"<<domain<<"' from "<<fromaddr.toString()<<", assuming spoof attempt."<<endl;
       g_stats.spoofCount++;
       return LWResult::Result::Spoofed;
     }
@@ -4187,7 +4188,7 @@ static void TCPIOHandlerStateChange(IOState oldstate, IOState newstate, std::sha
 
 static void TCPIOHandlerIO(int fd, FDMultiplexer::funcparam_t& var)
 {
-  auto pid = boost::any_cast<std::shared_ptr<PacketID>>(var);
+  std::shared_ptr<PacketID> pid = boost::any_cast<std::shared_ptr<PacketID>>(var);
   assert(pid->tcphandler);
   assert(fd == pid->tcphandler->getDescriptor());
   IOState newstate = IOState::Done;
@@ -4213,7 +4214,7 @@ static void TCPIOHandlerIO(int fd, FDMultiplexer::funcparam_t& var)
           pid->inMSG.resize(pid->inPos); // old content (if there) + new bytes read, only relevant for the inIncompleteOkay case
           newstate = IOState::Done;
           TCPIOHandlerStateChange(pid->lowState, newstate, pid);
-          MT->sendEvent(*pid, &pid->inMSG);
+          MT->sendEvent(pid, &pid->inMSG);
           return;
         }
         break;
@@ -4226,7 +4227,7 @@ static void TCPIOHandlerIO(int fd, FDMultiplexer::funcparam_t& var)
       TCPLOG(pid->tcpsock, "read exception..." << e.what() << endl);
       PacketBuffer empty;
       TCPIOHandlerStateChange(pid->lowState, newstate, pid);
-      MT->sendEvent(*pid, &empty); // this conveys error status
+      MT->sendEvent(pid, &empty); // this conveys error status
       return;
     }
     break;
@@ -4234,14 +4235,14 @@ static void TCPIOHandlerIO(int fd, FDMultiplexer::funcparam_t& var)
   case TCPAction::DoingWrite:
     TCPLOG(pid->tcpsock, "highState: Writing" << endl);
     try {
-      TCPLOG(pid->tcpsock, "tryWrite: " << pid->outPos << '/' << pid->outMSG.size() << ' ' << pid << " -> ");
+      TCPLOG(pid->tcpsock, "tryWrite: " << pid->outPos << '/' << pid->outMSG.size() << ' ' << " -> ");
       newstate = pid->tcphandler->tryWrite(pid->outMSG, pid->outPos, pid->outMSG.size());
       TCPLOG(pid->tcpsock, pid->outPos << '/' << pid->outMSG.size() << endl);
       switch (newstate) {
       case IOState::Done: {
         TCPLOG(pid->tcpsock, "tryWrite: Done" << endl);
         TCPIOHandlerStateChange(pid->lowState, newstate, pid);
-        MT->sendEvent(*pid, &pid->outMSG); // send back what we sent to convey everything is ok
+        MT->sendEvent(pid, &pid->outMSG); // send back what we sent to convey everything is ok
         return;
       }
       case IOState::NeedRead:
@@ -4257,7 +4258,7 @@ static void TCPIOHandlerIO(int fd, FDMultiplexer::funcparam_t& var)
       TCPLOG(pid->tcpsock, "write exception..." << e.what() << endl);
       PacketBuffer sent;
       TCPIOHandlerStateChange(pid->lowState, newstate, pid);
-      MT->sendEvent(*pid, &sent); // we convey error status by sending empty string
+      MT->sendEvent(pid, &sent); // we convey error status by sending empty string
       return;
     }
     break;
@@ -4268,27 +4269,25 @@ static void TCPIOHandlerIO(int fd, FDMultiplexer::funcparam_t& var)
 }
 
 // resend event to everybody chained onto it
-static void doResends(MT_t::waiters_t::iterator& iter, PacketID resend, const PacketBuffer& content)
+static void doResends(MT_t::waiters_t::iterator& iter, const std::shared_ptr<PacketID>& resend, const PacketBuffer& content)
 {
   // We close the chain for new entries, since they won't be processed anyway
-  iter->key.closed = true;
+  iter->key->closed = true;
 
-  if(iter->key.chain.empty())
+  if(iter->key->chain.empty())
     return;
-  //  cerr<<"doResends called!\n";
-  for(PacketID::chain_t::iterator i=iter->key.chain.begin(); i != iter->key.chain.end() ; ++i) {
-    resend.fd=-1;
-    resend.id=*i;
-    //    cerr<<"\tResending "<<content.size()<<" bytes for fd="<<resend.fd<<" and id="<<resend.id<<endl;
-
-    MT->sendEvent(resend, &content);
+  for(PacketID::chain_t::iterator i=iter->key->chain.begin(); i != iter->key->chain.end() ; ++i) {
+    auto r = std::make_shared<PacketID>(*resend);
+    r->fd = -1;
+    r->id = *i;
+    MT->sendEvent(r, &content);
     g_stats.chainResends++;
   }
 }
 
 static void handleUDPServerResponse(int fd, FDMultiplexer::funcparam_t& var)
 {
-  PacketID pid=boost::any_cast<PacketID>(var);
+  std::shared_ptr<PacketID> pid = boost::any_cast<std::shared_ptr<PacketID>>(var);
   ssize_t len;
   PacketBuffer packet;
   packet.resize(g_outgoingEDNSBufsize);
@@ -4322,10 +4321,10 @@ static void handleUDPServerResponse(int fd, FDMultiplexer::funcparam_t& var)
   dnsheader dh;
   memcpy(&dh, &packet.at(0), sizeof(dh));
 
-  PacketID pident;
-  pident.remote=fromaddr;
-  pident.id=dh.id;
-  pident.fd=fd;
+  auto pident = std::make_shared<PacketID>();
+  pident->remote = fromaddr;
+  pident->id = dh.id;
+  pident->fd = fd;
 
   if(!dh.qr && g_logCommonErrors) {
     g_log<<Logger::Notice<<"Not taking data from question on outgoing socket from "<< fromaddr.toStringWithPort()  <<endl;
@@ -4333,13 +4332,13 @@ static void handleUDPServerResponse(int fd, FDMultiplexer::funcparam_t& var)
 
   if(!dh.qdcount || // UPC, Nominum, very old BIND on FormErr, NSD
      !dh.qr) {      // one weird server
-    pident.domain.clear();
-    pident.type = 0;
+    pident->domain.clear();
+    pident->type = 0;
   }
   else {
     try {
       if(len > 12)
-        pident.domain=DNSName(reinterpret_cast<const char *>(packet.data()), len, 12, false, &pident.type); // don't copy this from above - we need to do the actual read
+        pident->domain=DNSName(reinterpret_cast<const char *>(packet.data()), len, 12, false, &pident->type); // don't copy this from above - we need to do the actual read
     }
     catch(std::exception& e) {
       g_stats.serverParseError++; // won't be fed to lwres.cc, so we have to increment
@@ -4360,26 +4359,26 @@ retryWithName:
 
     // we do a full scan for outstanding queries on unexpected answers. not too bad since we only accept them on the right port number, which is hard enough to guess
     for (MT_t::waiters_t::iterator mthread = MT->d_waiters.begin(); mthread != MT->d_waiters.end(); ++mthread) {
-      if (pident.fd == mthread->key.fd && mthread->key.remote == pident.remote &&  mthread->key.type == pident.type &&
-         pident.domain == mthread->key.domain) {
+      if (pident->fd == mthread->key->fd && mthread->key->remote == pident->remote &&  mthread->key->type == pident->type &&
+         pident->domain == mthread->key->domain) {
         /* we are expecting an answer from that exact source, on that exact port (since we are using connected sockets), for that qname/qtype,
            but with a different message ID. That smells like a spoofing attempt. For now we will just increase the counter and will deal with
            that later. */
-        mthread->key.nearMisses++;
+        mthread->key->nearMisses++;
       }
 
       // be a bit paranoid here since we're weakening our matching
-      if(pident.domain.empty() && !mthread->key.domain.empty() && !pident.type && mthread->key.type &&
-         pident.id  == mthread->key.id && mthread->key.remote == pident.remote) {
+      if(pident->domain.empty() && !mthread->key->domain.empty() && !pident->type && mthread->key->type &&
+         pident->id  == mthread->key->id && mthread->key->remote == pident->remote) {
         // cerr<<"Empty response, rest matches though, sending to a waiter"<<endl;
-        pident.domain = mthread->key.domain;
-        pident.type = mthread->key.type;
+        pident->domain = mthread->key->domain;
+        pident->type = mthread->key->type;
         goto retryWithName; // note that this only passes on an error, lwres will still reject the packet
       }
     }
     g_stats.unexpectedCount++; // if we made it here, it really is an unexpected answer
     if(g_logCommonErrors) {
-      g_log<<Logger::Warning<<"Discarding unexpected packet from "<<fromaddr.toStringWithPort()<<": "<< (pident.domain.empty() ? "<empty>" : pident.domain.toString())<<", "<<pident.type<<", "<<MT->d_waiters.size()<<" waiters"<<endl;
+      g_log<<Logger::Warning<<"Discarding unexpected packet from "<<fromaddr.toStringWithPort()<<": "<< (pident->domain.empty() ? "<empty>" : pident->domain.toString())<<", "<<pident->type<<", "<<MT->d_waiters.size()<<" waiters"<<endl;
     }
   }
   else if(fd >= 0) {
@@ -5460,7 +5459,7 @@ try
     t_bogusqueryring = std::unique_ptr<boost::circular_buffer<pair<DNSName, uint16_t> > >(new boost::circular_buffer<pair<DNSName, uint16_t> >());
     t_bogusqueryring->set_capacity(ringsize);
   }
-  MT=std::unique_ptr<MTasker<PacketID,PacketBuffer> >(new MTasker<PacketID,PacketBuffer>(::arg().asNum("stack-size")));
+  MT=std::unique_ptr<MTasker<std::shared_ptr<PacketID>,PacketBuffer> >(new MTasker<std::shared_ptr<PacketID>,PacketBuffer>(::arg().asNum("stack-size")));
   threadInfo.mt = MT.get();
 
   /* start protobuf export threads if needed */
index e1eb007f582c6236fdc4c626520dcba05a14ee4e..e4306cf73755e4a80cf77a3c31dc8aef671a2281 100644 (file)
@@ -966,12 +966,12 @@ static string* pleaseGetCurrentQueries()
   ostr << (fmt % "qname" % "qtype" % "remote" % "tcp" % "chained" % "spent(ms)");
   unsigned int n=0;
   for(const auto& mthread : getMT()->d_waiters) {
-    const PacketID& pident = mthread.key;
+    const std::shared_ptr<PacketID>& pident = mthread.key;
     const double spent = g_networkTimeoutMsec - (DiffTime(now, mthread.ttd) * 1000);
     ostr << (fmt 
-             % pident.domain.toLogString() /* ?? */ % DNSRecordContent::NumberToType(pident.type) 
-             % pident.remote.toString() % (pident.tcpsock ? 'Y' : 'n')
-             % (pident.fd == -1 ? 'Y' : 'n')
+             % pident->domain.toLogString() /* ?? */ % DNSRecordContent::NumberToType(pident->type) 
+             % pident->remote.toString() % (pident->tcpsock ? 'Y' : 'n')
+             % (pident->fd == -1 ? 'Y' : 'n')
              % (spent > 0 ? spent : '0')
              );
     ++n;
index c91b77ce90121ba48798210d93a3a57c5b35092b..ac08089d4706709543ef4383e201dffed81693f9 100644 (file)
@@ -967,12 +967,10 @@ struct PacketID
 
   bool operator<(const PacketID& b) const
   {
-    int ourSock= tcpsock;
-    int bSock = b.tcpsock;
-    if (tie(remote, ourSock, type) < tie(b.remote, bSock, b.type)) {
+    if (tie(remote, tcpsock, type) < tie(b.remote, b.tcpsock, b.type)) {
       return true;
     }
-    if (tie(remote, ourSock, type) > tie(b.remote, bSock, b.type)) {
+    if (tie(remote, tcpsock, type) > tie(b.remote, b.tcpsock, b.type)) {
       return false;
     }
 
@@ -980,25 +978,38 @@ struct PacketID
   }
 };
 
-struct PacketIDBirthdayCompare: public std::binary_function<PacketID, PacketID, bool>
+inline ostream& operator<<(ostream & os, const PacketID& pid)
 {
-  bool operator()(const PacketID& a, const PacketID& b) const
+  return os << "PacketID(id=" << pid.id << ",remote=" << pid.remote.toString() << ",type=" << pid.type << ",tcpsock=" <<
+    pid.tcpsock << "fd=" << pid.fd << ',' << pid.domain << ')';
+}
+
+inline ostream& operator<<(ostream & os, const shared_ptr<PacketID>& pid)
+{
+  return os << *pid;
+}
+
+inline bool operator<(const std::shared_ptr<PacketID>& a, const std::shared_ptr<PacketID>& b)
+{
+  return a->operator<(*b);
+}
+
+struct PacketIDBirthdayCompare
+{
+  bool operator()(const std::shared_ptr<PacketID>& a, const std::shared_ptr<PacketID>& b) const
   {
-    int ourSock= a.tcpsock;
-    int bSock = b.tcpsock;
-    if (tie(a.remote, ourSock, a.type) < tie(b.remote, bSock, b.type)) {
+    if (tie(a->remote, a->tcpsock, a->type) < tie(b->remote, b->tcpsock, b->type)) {
       return true;
     }
-    if (tie(a.remote, ourSock, a.type) > tie(b.remote, bSock, b.type)) {
+    if (tie(a->remote, a->tcpsock, a->type) > tie(b->remote, b->tcpsock, b->type)) {
       return false;
     }
-
-    return a.domain < b.domain;
+    return a->domain < b->domain;
   }
 };
 extern std::unique_ptr<MemRecursorCache> g_recCache;
 extern thread_local std::unique_ptr<RecursorPacketCache> t_packetCache;
-typedef MTasker<PacketID,PacketBuffer> MT_t;
+typedef MTasker<std::shared_ptr<PacketID>,PacketBuffer> MT_t;
 MT_t* getMT();
 
 struct RecursorStats