From: Otto Date: Tue, 27 Jul 2021 12:29:23 +0000 (+0200) Subject: Move to a shared PacketID X-Git-Tag: dnsdist-1.7.0-alpha1~74^2~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a2eef0008fcae7d515624d81818e1025457d02c6;p=thirdparty%2Fpdns.git Move to a shared PacketID There is an issue with equal_range() in asendto() that I fixed with an extra test for the moment. --- diff --git a/pdns/mtasker.cc b/pdns/mtasker.cc index 96bf9e0125..8776a4e33f 100644 --- a/pdns/mtasker.cc +++ b/pdns/mtasker.cc @@ -240,10 +240,9 @@ templateint MTasker::sendEven typename waiters_t::iterator waiter=d_waiters.find(key); if(waiter == d_waiters.end()) { - // cout<<"Event sent nobody was waiting for!"<& ha TCPIOHandlerStateChange(IOState::Done, state, pident); PacketBuffer packet; - int ret = MT->waitEvent(*pident, &packet, g_networkTimeoutMsec); + int ret = MT->waitEvent(pident, &packet, g_networkTimeoutMsec); TCPLOG(pident->tcpsock, "asendtcp waitEvent returned " << ret << ' ' << packet.size() << '/' << data.size() << ' '); if (ret == 0) { TCPLOG(pident->tcpsock, "timeout" << endl); @@ -505,7 +505,7 @@ LWResult::Result arecvtcp(PacketBuffer& data, const size_t len, shared_ptrlowState TCPIOHandlerStateChange(IOState::Done, state, pident); - int ret = MT->waitEvent(*pident, &data, g_networkTimeoutMsec); + int ret = MT->waitEvent(pident, &data, g_networkTimeoutMsec); TCPLOG(pident->tcpsock, "arecvtcp " << ret << ' ' << data.size() << ' ' ); if (ret == 0) { TCPLOG(pident->tcpsock, "timeout" << endl); @@ -529,15 +529,15 @@ LWResult::Result arecvtcp(PacketBuffer& data, const size_t len, shared_ptr(&var); + std::shared_ptr pident = boost::any_cast>(var); PacketBuffer resp; resp.resize(512); ComboAddress fromaddr; socklen_t addrlen = sizeof(fromaddr); ssize_t ret = recvfrom(fd, resp.data(), resp.size(), 0, (sockaddr *)&fromaddr, &addrlen); - if (fromaddr != pident.remote) { - g_log<remote) { + g_log<remote.toStringWithPort()<<"), discarding"< pident = std::make_shared(); + pident->fd = s.getHandle(); + pident->remote = dest; + pident->type = 0; t_fdm->addReadFD(s.getHandle(), handleGenUDPQueryResponse, pident); PacketBuffer data; int ret=MT->waitEvent(pident, &data, g_networkTimeoutMsec); - if(!ret || ret==-1) { // timeout + if (!ret || ret==-1) { // timeout t_fdm->removeReadFD(s.getHandle()); } else if(data.empty()) {// error, EOF or other @@ -731,24 +731,25 @@ LWResult::Result asendto(const char *data, size_t len, int flags, const ComboAddress& toaddr, uint16_t id, const DNSName& domain, uint16_t qtype, int* fd) { - PacketID pident; - pident.domain = domain; - pident.remote = toaddr; - pident.type = qtype; + auto pident = std::make_shared(); + pident->domain = domain; + pident->remote = toaddr; + pident->type = qtype; // see if there is an existing outstanding request we can chain on to, using partial equivalence function looking for the same // query (qname and qtype) to the same host, but with a different message ID pair chain=MT->d_waiters.equal_range(pident, PacketIDBirthdayCompare()); for(; chain.first != chain.second; chain.first++) { - if(chain.first->key.fd > -1 && !chain.first->key.closed) { // don't chain onto existing chained waiter or a chain already processed - /* - cerr<<"Orig: "<key.domain<<", "<key.remote.toString()<<", id="<key.id - <<", count="<key.chain.size()<<", origfd: "<key.fd<key.chain.insert(id); // we can chain - *fd=-1; // gets used in waitEvent / sendEvent later on + assert(chain.first->key->domain == pident->domain); + if (chain.first->key->domain != pident->domain) { + // XXX Actually, this should not happen..., but it does + continue; + } + if(chain.first->key->fd > -1 && !chain.first->key->closed) { // don't chain onto existing chained waiter or a chain already processed + //cerr << "Insert " << id << ' ' << pident << " into chain for " << chain.first->key << endl; + chain.first->key->chain.insert(id); // we can chain + *fd = -1; // gets used in waitEvent / sendEvent later on return LWResult::Result::Success; } } @@ -758,8 +759,8 @@ LWResult::Result asendto(const char *data, size_t len, int flags, return ret; } - pident.fd=*fd; - pident.id=id; + pident->fd=*fd; + pident->id=id; t_fdm->addReadFD(*fd, handleUDPServerResponse, pident); ssize_t sent = send(*fd, data, len, 0); @@ -780,12 +781,12 @@ LWResult::Result arecvfrom(PacketBuffer& packet, int flags, const ComboAddress& { static const unsigned int nearMissLimit = ::arg().asNum("spoof-nearmiss-max"); - PacketID pident; - pident.fd=fd; - pident.id=id; - pident.domain=domain; - pident.type = qtype; - pident.remote=fromaddr; + auto pident = std::make_shared(); + pident->fd = fd; + pident->id = id; + pident->domain = domain; + pident->type = qtype; + pident->remote = fromaddr; int ret=MT->waitEvent(pident, &packet, g_networkTimeoutMsec, now); @@ -798,10 +799,10 @@ LWResult::Result arecvfrom(PacketBuffer& packet, int flags, const ComboAddress& *d_len=packet.size(); - if (nearMissLimit > 0 && pident.nearMisses > nearMissLimit) { + if (nearMissLimit > 0 && pident->nearMisses > nearMissLimit) { /* we have received more than nearMissLimit answers on the right IP and port, from the right source (we are using connected sockets), for the correct qname and qtype, but with an unexpected message ID. That looks like a spoofing attempt. */ - g_log< "<nearMisses<<" > "<>(var); + std::shared_ptr pid = boost::any_cast>(var); assert(pid->tcphandler); assert(fd == pid->tcphandler->getDescriptor()); IOState newstate = IOState::Done; @@ -4213,7 +4214,7 @@ static void TCPIOHandlerIO(int fd, FDMultiplexer::funcparam_t& var) pid->inMSG.resize(pid->inPos); // old content (if there) + new bytes read, only relevant for the inIncompleteOkay case newstate = IOState::Done; TCPIOHandlerStateChange(pid->lowState, newstate, pid); - MT->sendEvent(*pid, &pid->inMSG); + MT->sendEvent(pid, &pid->inMSG); return; } break; @@ -4226,7 +4227,7 @@ static void TCPIOHandlerIO(int fd, FDMultiplexer::funcparam_t& var) TCPLOG(pid->tcpsock, "read exception..." << e.what() << endl); PacketBuffer empty; TCPIOHandlerStateChange(pid->lowState, newstate, pid); - MT->sendEvent(*pid, &empty); // this conveys error status + MT->sendEvent(pid, &empty); // this conveys error status return; } break; @@ -4234,14 +4235,14 @@ static void TCPIOHandlerIO(int fd, FDMultiplexer::funcparam_t& var) case TCPAction::DoingWrite: TCPLOG(pid->tcpsock, "highState: Writing" << endl); try { - TCPLOG(pid->tcpsock, "tryWrite: " << pid->outPos << '/' << pid->outMSG.size() << ' ' << pid << " -> "); + TCPLOG(pid->tcpsock, "tryWrite: " << pid->outPos << '/' << pid->outMSG.size() << ' ' << " -> "); newstate = pid->tcphandler->tryWrite(pid->outMSG, pid->outPos, pid->outMSG.size()); TCPLOG(pid->tcpsock, pid->outPos << '/' << pid->outMSG.size() << endl); switch (newstate) { case IOState::Done: { TCPLOG(pid->tcpsock, "tryWrite: Done" << endl); TCPIOHandlerStateChange(pid->lowState, newstate, pid); - MT->sendEvent(*pid, &pid->outMSG); // send back what we sent to convey everything is ok + MT->sendEvent(pid, &pid->outMSG); // send back what we sent to convey everything is ok return; } case IOState::NeedRead: @@ -4257,7 +4258,7 @@ static void TCPIOHandlerIO(int fd, FDMultiplexer::funcparam_t& var) TCPLOG(pid->tcpsock, "write exception..." << e.what() << endl); PacketBuffer sent; TCPIOHandlerStateChange(pid->lowState, newstate, pid); - MT->sendEvent(*pid, &sent); // we convey error status by sending empty string + MT->sendEvent(pid, &sent); // we convey error status by sending empty string return; } break; @@ -4268,27 +4269,25 @@ static void TCPIOHandlerIO(int fd, FDMultiplexer::funcparam_t& var) } // resend event to everybody chained onto it -static void doResends(MT_t::waiters_t::iterator& iter, PacketID resend, const PacketBuffer& content) +static void doResends(MT_t::waiters_t::iterator& iter, const std::shared_ptr& resend, const PacketBuffer& content) { // We close the chain for new entries, since they won't be processed anyway - iter->key.closed = true; + iter->key->closed = true; - if(iter->key.chain.empty()) + if(iter->key->chain.empty()) return; - // cerr<<"doResends called!\n"; - for(PacketID::chain_t::iterator i=iter->key.chain.begin(); i != iter->key.chain.end() ; ++i) { - resend.fd=-1; - resend.id=*i; - // cerr<<"\tResending "<domain.clear(); + pident->type = 0; } else { try { if(len > 12) - pident.domain=DNSName(reinterpret_cast(packet.data()), len, 12, false, &pident.type); // don't copy this from above - we need to do the actual read + pident->domain=DNSName(reinterpret_cast(packet.data()), len, 12, false, &pident->type); // don't copy this from above - we need to do the actual read } catch(std::exception& e) { g_stats.serverParseError++; // won't be fed to lwres.cc, so we have to increment @@ -4360,26 +4359,26 @@ retryWithName: // we do a full scan for outstanding queries on unexpected answers. not too bad since we only accept them on the right port number, which is hard enough to guess for (MT_t::waiters_t::iterator mthread = MT->d_waiters.begin(); mthread != MT->d_waiters.end(); ++mthread) { - if (pident.fd == mthread->key.fd && mthread->key.remote == pident.remote && mthread->key.type == pident.type && - pident.domain == mthread->key.domain) { + if (pident->fd == mthread->key->fd && mthread->key->remote == pident->remote && mthread->key->type == pident->type && + pident->domain == mthread->key->domain) { /* we are expecting an answer from that exact source, on that exact port (since we are using connected sockets), for that qname/qtype, but with a different message ID. That smells like a spoofing attempt. For now we will just increase the counter and will deal with that later. */ - mthread->key.nearMisses++; + mthread->key->nearMisses++; } // be a bit paranoid here since we're weakening our matching - if(pident.domain.empty() && !mthread->key.domain.empty() && !pident.type && mthread->key.type && - pident.id == mthread->key.id && mthread->key.remote == pident.remote) { + if(pident->domain.empty() && !mthread->key->domain.empty() && !pident->type && mthread->key->type && + pident->id == mthread->key->id && mthread->key->remote == pident->remote) { // cerr<<"Empty response, rest matches though, sending to a waiter"<key.domain; - pident.type = mthread->key.type; + pident->domain = mthread->key->domain; + pident->type = mthread->key->type; goto retryWithName; // note that this only passes on an error, lwres will still reject the packet } } g_stats.unexpectedCount++; // if we made it here, it really is an unexpected answer if(g_logCommonErrors) { - g_log<" : pident.domain.toString())<<", "<d_waiters.size()<<" waiters"<domain.empty() ? "" : pident->domain.toString())<<", "<type<<", "<d_waiters.size()<<" waiters"<= 0) { @@ -5460,7 +5459,7 @@ try t_bogusqueryring = std::unique_ptr > >(new boost::circular_buffer >()); t_bogusqueryring->set_capacity(ringsize); } - MT=std::unique_ptr >(new MTasker(::arg().asNum("stack-size"))); + MT=std::unique_ptr,PacketBuffer> >(new MTasker,PacketBuffer>(::arg().asNum("stack-size"))); threadInfo.mt = MT.get(); /* start protobuf export threads if needed */ diff --git a/pdns/rec_channel_rec.cc b/pdns/rec_channel_rec.cc index e1eb007f58..e4306cf737 100644 --- a/pdns/rec_channel_rec.cc +++ b/pdns/rec_channel_rec.cc @@ -966,12 +966,12 @@ static string* pleaseGetCurrentQueries() ostr << (fmt % "qname" % "qtype" % "remote" % "tcp" % "chained" % "spent(ms)"); unsigned int n=0; for(const auto& mthread : getMT()->d_waiters) { - const PacketID& pident = mthread.key; + const std::shared_ptr& pident = mthread.key; const double spent = g_networkTimeoutMsec - (DiffTime(now, mthread.ttd) * 1000); ostr << (fmt - % pident.domain.toLogString() /* ?? */ % DNSRecordContent::NumberToType(pident.type) - % pident.remote.toString() % (pident.tcpsock ? 'Y' : 'n') - % (pident.fd == -1 ? 'Y' : 'n') + % pident->domain.toLogString() /* ?? */ % DNSRecordContent::NumberToType(pident->type) + % pident->remote.toString() % (pident->tcpsock ? 'Y' : 'n') + % (pident->fd == -1 ? 'Y' : 'n') % (spent > 0 ? spent : '0') ); ++n; diff --git a/pdns/syncres.hh b/pdns/syncres.hh index c91b77ce90..ac08089d47 100644 --- a/pdns/syncres.hh +++ b/pdns/syncres.hh @@ -967,12 +967,10 @@ struct PacketID bool operator<(const PacketID& b) const { - int ourSock= tcpsock; - int bSock = b.tcpsock; - if (tie(remote, ourSock, type) < tie(b.remote, bSock, b.type)) { + if (tie(remote, tcpsock, type) < tie(b.remote, b.tcpsock, b.type)) { return true; } - if (tie(remote, ourSock, type) > tie(b.remote, bSock, b.type)) { + if (tie(remote, tcpsock, type) > tie(b.remote, b.tcpsock, b.type)) { return false; } @@ -980,25 +978,38 @@ struct PacketID } }; -struct PacketIDBirthdayCompare: public std::binary_function +inline ostream& operator<<(ostream & os, const PacketID& pid) { - bool operator()(const PacketID& a, const PacketID& b) const + return os << "PacketID(id=" << pid.id << ",remote=" << pid.remote.toString() << ",type=" << pid.type << ",tcpsock=" << + pid.tcpsock << "fd=" << pid.fd << ',' << pid.domain << ')'; +} + +inline ostream& operator<<(ostream & os, const shared_ptr& pid) +{ + return os << *pid; +} + +inline bool operator<(const std::shared_ptr& a, const std::shared_ptr& b) +{ + return a->operator<(*b); +} + +struct PacketIDBirthdayCompare +{ + bool operator()(const std::shared_ptr& a, const std::shared_ptr& b) const { - int ourSock= a.tcpsock; - int bSock = b.tcpsock; - if (tie(a.remote, ourSock, a.type) < tie(b.remote, bSock, b.type)) { + if (tie(a->remote, a->tcpsock, a->type) < tie(b->remote, b->tcpsock, b->type)) { return true; } - if (tie(a.remote, ourSock, a.type) > tie(b.remote, bSock, b.type)) { + if (tie(a->remote, a->tcpsock, a->type) > tie(b->remote, b->tcpsock, b->type)) { return false; } - - return a.domain < b.domain; + return a->domain < b->domain; } }; extern std::unique_ptr g_recCache; extern thread_local std::unique_ptr t_packetCache; -typedef MTasker MT_t; +typedef MTasker,PacketBuffer> MT_t; MT_t* getMT(); struct RecursorStats