]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Refactor cache hit path and do cache check for tcp
authorOtto <otto.moerbeek@open-xchange.com>
Mon, 18 Jan 2021 15:53:06 +0000 (16:53 +0100)
committerOtto <otto.moerbeek@open-xchange.com>
Fri, 29 Jan 2021 12:45:41 +0000 (13:45 +0100)
pdns/pdns_recursor.cc
pdns/recpacketcache.cc
pdns/recpacketcache.hh
pdns/test-recpacketcache_cc.cc

index 6a124c5ab7fa06ca268fb6d48226b795616d2f93..422b7728a614a3012a4b1968ee8a66b9f736a8e1 100644 (file)
@@ -49,6 +49,7 @@
 #include <stdio.h>
 #include <signal.h>
 #include <stdlib.h>
+#include <stdint.h>
 #include "misc.hh"
 #include "mtasker.hh"
 #include <utility>
@@ -135,7 +136,7 @@ std::unique_ptr<NegCache> g_negCache;
 thread_local std::unique_ptr<RecursorPacketCache> t_packetCache;
 thread_local FDMultiplexer* t_fdm{nullptr};
 thread_local std::unique_ptr<addrringbuf_t> t_remotes, t_servfailremotes, t_largeanswerremotes, t_bogusremotes;
-thread_local std::unique_ptr<boost::circular_buffer<pair<DNSName, uint16_t> > > t_queryring, t_servfailqueryring, t_bogusqueryring;
+thread_local std::unique_ptr<boost::circular_buffer<pair<DNSName, uint16_t>>> t_queryring, t_servfailqueryring, t_bogusqueryring;
 thread_local std::shared_ptr<NetmaskGroup> t_allowFrom;
 #ifdef NOD_ENABLED
 thread_local std::shared_ptr<nod::NODDB> t_nodDBp;
@@ -448,7 +449,7 @@ static void handleGenUDPQueryResponse(int fd, FDMultiplexer::funcparam_t& var)
   ComboAddress fromaddr;
   socklen_t addrlen=sizeof(fromaddr);
 
-  ssize_t ret=recvfrom(fd, resp, sizeof(resp), 0, (sockaddr *)&fromaddr, &addrlen);
+  ssize_t ret=recvfrom(fd, resp, sizeof(resp), 0, (struct sockaddr *)&fromaddr, &addrlen);
   if (fromaddr != pident.remote) {
     g_log<<Logger::Notice<<"Response received from the wrong remote host ("<<fromaddr.toStringWithPort()<<" instead of "<<pident.remote.toStringWithPort()<<"), discarding"<<endl;
 
@@ -781,7 +782,8 @@ static void terminateTCPConnection(int fd)
   }
 }
 
-static bool sendResponseOverTCP(const std::unique_ptr<DNSComboWriter>& dc, const std::vector<uint8_t>& packet)
+template <class T>
+static bool sendResponseOverTCP(const std::unique_ptr<DNSComboWriter>& dc, const T& packet)
 {
   char buf[2];
   buf[0] = packet.size() / 256;
@@ -918,6 +920,54 @@ static void protobufLogResponse(pdns::ProtoZero::RecMessage& message)
   }
 }
 
+static void protobufLogResponse(const struct dnsheader* dh,
+                                const RecursorPacketCache::OptPBData& pbData, const struct timeval& tv,
+                                bool tcp, const ComboAddress& source, const ComboAddress& destination,
+                                const EDNSSubnetOpts& ednssubnet,
+                                const boost::uuids::uuid& uniqueId, const string& requestorId, const string& deviceId,
+                                const string& deviceName) 
+{
+  pdns::ProtoZero::RecMessage pbMessage(pbData ? pbData->d_message : "", pbData ? pbData->d_response : "", 64, 10); // The extra bytes we are going to add
+  if (pbData) {
+    // We take the inmutable string from the cache and are appending a few values
+  } else {
+    pbMessage.setType(pdns::ProtoZero::Message::MessageType::DNSResponseType);
+    pbMessage.setServerIdentity(SyncRes::s_serverID);
+  }
+
+  // In response part
+  if (g_useKernelTimestamp && tv.tv_sec) {
+    pbMessage.setQueryTime(tv.tv_sec, tv.tv_usec);
+  }
+  else {
+    pbMessage.setQueryTime(g_now.tv_sec, g_now.tv_usec);
+  }
+
+  auto luaconfsLocal = g_luaconfs.getLocal(); 
+  // In message part
+  Netmask requestorNM(source, source.sin4.sin_family == AF_INET ? luaconfsLocal->protobufMaskV4 : luaconfsLocal->protobufMaskV6);
+  ComboAddress requestor = requestorNM.getMaskedNetwork();
+  pbMessage.setMessageIdentity(uniqueId);
+  pbMessage.setFrom(requestor);
+  pbMessage.setTo(destination);
+  pbMessage.setSocketProtocol(tcp);
+  pbMessage.setId(dh->id);
+
+  pbMessage.setTime();
+  pbMessage.setEDNSSubnet(ednssubnet.source, ednssubnet.source.isIPv4() ? luaconfsLocal->protobufMaskV4 : luaconfsLocal->protobufMaskV6);
+  pbMessage.setRequestorId(requestorId);
+  pbMessage.setDeviceId(deviceId);
+  pbMessage.setDeviceName(deviceName);
+  pbMessage.setFromPort(source.getPort());
+  pbMessage.setToPort(destination.getPort());
+#ifdef NOD_ENABLED
+  if (g_nodEnabled) {
+    pbMessage.setNewlyObservedDomain(false);
+  }
+#endif
+  protobufLogResponse(pbMessage);
+}
+
 /**
  * Chases the CNAME provided by the PolicyCustom RPZ policy.
  *
@@ -2007,6 +2057,18 @@ static void startDoResolve(void *p)
       }
     }
 
+    if(variableAnswer || sr.wasVariable()) {
+      g_stats.variableResponses++;
+    }
+    if(!SyncRes::s_nopacketcache && !variableAnswer && !sr.wasVariable() ) {
+      t_packetCache->insertResponsePacket(dc->d_tag, dc->d_qhash, std::move(dc->d_query), dc->d_mdp.d_qname, dc->d_mdp.d_qtype, dc->d_mdp.d_qclass,
+                                          string((const char*)&*packet.begin(), packet.size()),
+                                          g_now.tv_sec,
+                                          pw.getHeader()->rcode == RCode::ServFail ? SyncRes::s_packetcacheservfailttl :
+                                          min(minTTL,SyncRes::s_packetcachettl),
+                                          dq.validationState,
+                                          std::move(pbDataForCache), dc->d_tcp);
+    }
     if(!dc->d_tcp) {
       struct msghdr msgh;
       struct iovec iov;
@@ -2023,18 +2085,6 @@ static void startDoResolve(void *p)
               << strerror(sendErr) << endl;
       }
 
-      if(variableAnswer || sr.wasVariable()) {
-        g_stats.variableResponses++;
-      }
-      if(!SyncRes::s_nopacketcache && !variableAnswer && !sr.wasVariable() ) {
-        t_packetCache->insertResponsePacket(dc->d_tag, dc->d_qhash, std::move(dc->d_query), dc->d_mdp.d_qname, dc->d_mdp.d_qtype, dc->d_mdp.d_qclass,
-                                            string((const char*)&*packet.begin(), packet.size()),
-                                            g_now.tv_sec,
-                                            pw.getHeader()->rcode == RCode::ServFail ? SyncRes::s_packetcacheservfailttl :
-                                            min(minTTL,SyncRes::s_packetcachettl),
-                                            dq.validationState,
-                                            std::move(pbDataForCache));
-      }
     }
     else {
       bool hadError = sendResponseOverTCP(dc, packet);
@@ -2300,6 +2350,44 @@ static bool handleTCPReadResult(int fd, ssize_t bytes)
   return true;
 }
 
+static bool checkForCacheHit(bool qnameParsed, unsigned int tag, const string& data,
+                             DNSName& qname, uint16_t& qtype, uint16_t& qclass,
+                             const struct timeval& now,
+                             string& response, uint32_t& age, vState& valState, uint32_t& qhash,
+                             RecursorPacketCache::OptPBData& pbData, bool tcp, const ComboAddress& source) 
+{
+  bool cacheHit = false;
+
+  if (qnameParsed) {
+    cacheHit = !SyncRes::s_nopacketcache && t_packetCache->getResponsePacket(tag, data, qname, qtype, qclass, now.tv_sec, &response, &age, &valState, &qhash, &pbData, tcp);
+  } else {
+    cacheHit = !SyncRes::s_nopacketcache && t_packetCache->getResponsePacket(tag, data, qname, &qtype, &qclass, now.tv_sec, &response, &age, &valState, &qhash, &pbData, tcp);
+  }
+
+  if (cacheHit) {
+    if (vStateIsBogus(valState)) {
+      if (t_bogusremotes) {
+        t_bogusremotes->push_back(source);
+      }
+      if (t_bogusqueryring) {
+        t_bogusqueryring->push_back(make_pair(qname, qtype));
+      }
+    }
+
+    g_stats.packetCacheHits++;
+    SyncRes::s_queries++;
+    ageDNSPacket(response, age);
+    if (response.length() >= sizeof(struct dnsheader)) {
+      struct dnsheader tmpdh;
+      memcpy(&tmpdh, response.data(), sizeof(tmpdh)); // XXX Only needed if response.data() isn't aligned
+      updateResponseStats(tmpdh.rcode, source, response.length(), 0, 0);
+    }
+    g_stats.avgLatencyUsec = (1.0 - 1.0 / g_latencyStatSize) * g_stats.avgLatencyUsec + 0.0; // we assume 0 usec
+    g_stats.avgLatencyOursUsec = (1.0 - 1.0 / g_latencyStatSize) * g_stats.avgLatencyOursUsec + 0.0; // we assume 0 usec
+  }
+  return cacheHit;
+}
+
 static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
 {
   shared_ptr<TCPConnection> conn=boost::any_cast<shared_ptr<TCPConnection> >(var);
@@ -2444,7 +2532,7 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
       dest.reset();
       dest.sin4.sin_family = conn->d_remote.sin4.sin_family;
       socklen_t len = dest.getSocklen();
-      getsockname(conn->getFD(), (sockaddr*)&dest, &len); // if this fails, we're ok with it
+      getsockname(conn->getFD(), (struct sockaddr*)&dest, &len); // if this fails, we're ok with it
       dc->setLocal(dest);
       dc->setDestination(conn->d_destination);
       /* we can't move this if we want to be able to access the values in
@@ -2459,6 +2547,7 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
       string deviceId;
       string deviceName;
       bool logQuery = false;
+      bool qnameParsed = false;
 
       auto luaconfsLocal = g_luaconfs.getLocal();
       if (checkProtobufExport(luaconfsLocal)) {
@@ -2481,6 +2570,7 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
           getQNameAndSubnet(conn->data, &qname, &qtype, &qclass,
                             dc->d_ecsFound, &dc->d_ednssubnet, g_gettagNeedsEDNSOptions ? &ednsOptions : nullptr,
                             xpfFound, needXPF ? &dc->d_source : nullptr, needXPF ? &dc->d_destination : nullptr);
+          qnameParsed = true;
 
           if(t_pdl) {
             try {
@@ -2509,7 +2599,7 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
 
       const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(&conn->data[0]);
 
-      if(t_protobufServers || t_outgoingProtobufServers) {
+      if (t_protobufServers || t_outgoingProtobufServers) {
         dc->d_requestorId = requestorId;
         dc->d_deviceId = deviceId;
         dc->d_deviceName = deviceName;
@@ -2568,16 +2658,74 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
       else {
         ++g_stats.qcounter;
         ++g_stats.tcpqcounter;
-        ++conn->d_requestsInFlight;
-        if (conn->d_requestsInFlight >= TCPConnection::s_maxInFlight) {
-          t_fdm->removeReadFD(fd); // should no longer awake ourselves when there is data to read
+
+        string response;
+        uint32_t age;
+        vState valState;
+        uint32_t qhash = 0;
+        RecursorPacketCache::OptPBData pbData{boost::none};
+
+        bool cacheHit = checkForCacheHit(qnameParsed, dc->d_tag, conn->data, qname, qtype, qclass, g_now, response, age, valState, qhash, pbData, true, dc->d_source);
+        dc->d_qhash=qhash;
+
+        if (cacheHit) {
+          if (t_protobufServers && dc->d_logResponse && !(luaconfsLocal->protobufExportConfig.taggedOnly && pbData && !pbData->d_tagged)) {
+            struct timeval tv{0, 0};
+            protobufLogResponse(dh, pbData, tv, true, dc->d_source, dc->d_destination, dc->d_ednssubnet, dc->d_uuid, dc->d_requestorId, dc->d_deviceId, dc->d_deviceName);
+          }
+
+          if (!g_quiet) {
+            g_log<<Logger::Notice<<t_id<< " TCP question answered from packet cache tag="<<dc->d_tag<<" from "<<dc->d_source.toStringWithPort()<<(dc->d_source != dc->d_remote ? " (via "+dc->d_remote.toStringWithPort()+")" : "")<<endl;
+          }
+
+          bool hadError = sendResponseOverTCP(dc, response);
+
+          // In the code below, we try to remove the fd from the set, but
+          // we don't know if another mthread already did the remove, so we can get a
+          // "Tried to remove unlisted fd" exception.  Not that an inflight < limit test
+          // will not work since we do not know if the other mthread got an error or not.
+          if (hadError) {
+            terminateTCPConnection(dc->d_socket);
+            dc->d_socket = -1;
+          }
+          else {
+            dc->d_tcpConnection->queriesCount++;
+            if (g_tcpMaxQueriesPerConn && dc->d_tcpConnection->queriesCount >= g_tcpMaxQueriesPerConn) {
+              try {
+                t_fdm->removeReadFD(dc->d_socket);
+              }
+              catch (FDMultiplexerException &) {
+              }
+              dc->d_socket = -1;
+            }
+            else {
+              Utility::gettimeofday(&g_now, nullptr); // needs to be updated
+              struct timeval ttd = g_now;
+              // fd might have been removed by read error code, or a read timeout, so expect an exception
+              try {
+                t_fdm->setReadTTD(dc->d_socket, ttd, g_tcpTimeout);
+              }
+              catch (const FDMultiplexerException &) {
+                // but if the FD was removed because of a timeout while we were sending a response,
+                // we need to re-arm it. If it was an error it will error again.
+                ttd.tv_sec += g_tcpTimeout;
+                t_fdm->addReadFD(dc->d_socket, handleRunningTCPQuestion, dc->d_tcpConnection, &ttd);
+              }
+            }
+          }
+          return;
         } else {
-          Utility::gettimeofday(&g_now, 0); // needed?
-          struct timeval ttd = g_now;
-          t_fdm->setReadTTD(fd, ttd, g_tcpTimeout);
+          ++conn->d_requestsInFlight;
+          if (conn->d_requestsInFlight >= TCPConnection::s_maxInFlight) {
+            t_fdm->removeReadFD(fd); // should no longer awake ourselves when there is data to read
+          } else {
+            Utility::gettimeofday(&g_now, 0); // needed?
+            struct timeval ttd = g_now;
+            t_fdm->setReadTTD(fd, ttd, g_tcpTimeout);
+          }
+          MT->makeThread(startDoResolve, dc.release()); // deletes dc
+          return;
         }
-        MT->makeThread(startDoResolve, dc.release()); // deletes dc
-        return;
       }
     }
   }
@@ -2773,7 +2921,6 @@ static string* doProcessUDPQuestion(const std::string& question, const ComboAddr
       }
     }
 
-    bool cacheHit = false;
     RecursorPacketCache::OptPBData pbData{boost::none};
 
     if (t_protobufServers) {
@@ -2786,67 +2933,15 @@ static string* doProcessUDPQuestion(const std::string& question, const ComboAddr
        but it means that the hash would not be computed. If some script decides at a later time to mark back the answer
        as cacheable we would cache it with a wrong tag, so better safe than sorry. */
     vState valState;
-    if (qnameParsed) {
-      cacheHit = !SyncRes::s_nopacketcache && t_packetCache->getResponsePacket(ctag, question, qname, qtype, qclass, g_now.tv_sec, &response, &age, &valState, &qhash, &pbData);
-    }
-    else {
-      cacheHit = !SyncRes::s_nopacketcache && t_packetCache->getResponsePacket(ctag, question, qname, &qtype, &qclass, g_now.tv_sec, &response, &age, &valState, &qhash, &pbData);
-    }
-
+    bool cacheHit = checkForCacheHit(qnameParsed, ctag, question, qname, qtype, qclass, g_now, response, age, valState, qhash, pbData, false, source);
     if (cacheHit) {
-      if (vStateIsBogus(valState)) {
-        if(t_bogusremotes)
-          t_bogusremotes->push_back(source);
-        if(t_bogusqueryring)
-          t_bogusqueryring->push_back(make_pair(qname, qtype));
+      if (t_protobufServers && logResponse && !(luaconfsLocal->protobufExportConfig.taggedOnly && pbData && !pbData->d_tagged)) {
+        protobufLogResponse(dh, pbData, tv, false, source, destination, ednssubnet, uniqueId, requestorId, deviceId, deviceName);
       }
 
-      if (t_protobufServers && logResponse && !(luaconfsLocal->protobufExportConfig.taggedOnly && pbData && !pbData->d_tagged)) { // XXX
-        pdns::ProtoZero::RecMessage pbMessage(pbData ? pbData->d_message : "", pbData ? pbData->d_response : "", 64, 10); // The extra bytes we are going to add
-        if (pbData) {
-          // We take the inmutable string from the cache and are appending a few values
-        } else {
-          pbMessage.setType(pdns::ProtoZero::Message::MessageType::DNSResponseType);
-          pbMessage.setServerIdentity(SyncRes::s_serverID);
-        }
-
-        // In response part
-        if (g_useKernelTimestamp && tv.tv_sec) {
-          pbMessage.setQueryTime(tv.tv_sec, tv.tv_usec);
-        }
-        else {
-          pbMessage.setQueryTime(g_now.tv_sec, g_now.tv_usec);
-        }
-        // In message part
-        Netmask requestorNM(source, source.sin4.sin_family == AF_INET ? luaconfsLocal->protobufMaskV4 : luaconfsLocal->protobufMaskV6);
-        ComboAddress requestor = requestorNM.getMaskedNetwork();
-        pbMessage.setMessageIdentity(uniqueId);
-        pbMessage.setFrom(requestor);
-        pbMessage.setTo(destination);
-        pbMessage.setSocketProtocol(false);
-        pbMessage.setId(dh->id);
-
-        pbMessage.setTime();
-        pbMessage.setEDNSSubnet(ednssubnet.source, ednssubnet.source.isIPv4() ? luaconfsLocal->protobufMaskV4 : luaconfsLocal->protobufMaskV6);
-        pbMessage.setRequestorId(requestorId);
-        pbMessage.setDeviceId(deviceId);
-        pbMessage.setDeviceName(deviceName);
-        pbMessage.setFromPort(source.getPort());
-        pbMessage.setToPort(destination.getPort());
-#ifdef NOD_ENABLED
-        if (g_nodEnabled) {
-          pbMessage.setNewlyObservedDomain(false);
-        }
-#endif
-        protobufLogResponse(pbMessage);
-      }
-
-      if(!g_quiet)
+      if (!g_quiet) {
         g_log<<Logger::Notice<<t_id<< " question answered from packet cache tag="<<ctag<<" from "<<source.toStringWithPort()<<(source != fromaddr ? " (via "+fromaddr.toStringWithPort()+")" : "")<<endl;
-
-      g_stats.packetCacheHits++;
-      SyncRes::s_queries++;
-      ageDNSPacket(response, age);
+      }
       struct msghdr msgh;
       struct iovec iov;
       cmsgbuf_aligned cbuf;
@@ -2862,6 +2957,7 @@ static string* doProcessUDPQuestion(const std::string& question, const ComboAddr
               << (source != fromaddr ? " (via " + fromaddr.toStringWithPort() + ")" : "") << " failed with: "
               << strerror(sendErr) << endl;
       }
+#if 0
       if(response.length() >= sizeof(struct dnsheader)) {
         struct dnsheader tmpdh;
         memcpy(&tmpdh, response.c_str(), sizeof(tmpdh));
@@ -2869,6 +2965,7 @@ static string* doProcessUDPQuestion(const std::string& question, const ComboAddr
       }
       g_stats.avgLatencyUsec=(1-1.0/g_latencyStatSize)*g_stats.avgLatencyUsec + 0.0; // we assume 0 usec
       g_stats.avgLatencyOursUsec=(1-1.0/g_latencyStatSize)*g_stats.avgLatencyOursUsec + 0.0; // we assume 0 usec
+#endif
       return 0;
     }
   }
@@ -2986,8 +3083,7 @@ static void handleNewUDPQuestion(int fd, FDMultiplexer::funcparam_t& var)
       else if (len > 512) {
         /* we only allow UDP packets larger than 512 for those with a proxy protocol header */
         g_stats.truncatedDrops++;
-        if (!g_quiet) {
-          g_log<<Logger::Error<<"Ignoring truncated query from "<<fromaddr.toStringWithPort()<<endl;
+        if (!g_quiet) {          g_log<<Logger::Error<<"Ignoring truncated query from "<<fromaddr.toStringWithPort()<<endl;
         }
         return;
       }
@@ -3067,7 +3163,7 @@ static void handleNewUDPQuestion(int fd, FDMultiplexer::funcparam_t& var)
             else {
               dest.sin4.sin_family = fromaddr.sin4.sin_family;
               socklen_t slen = dest.getSocklen();
-              getsockname(fd, (sockaddr*)&dest, &slen); // if this fails, we're ok with it
+              getsockname(fd, (struct sockaddr*)&dest, &slen); // if this fails, we're ok with it
             }
           }
           if (!proxyProto) {
index 047c18311e624703386f15b5c93d5fc5fe08fcd2..570fc830139ad8086cd0ee5a29a4e433d6d82f14 100644 (file)
@@ -42,10 +42,11 @@ int RecursorPacketCache::doWipePacketCache(const DNSName& name, uint16_t qtype,
   return count;
 }
 
-bool RecursorPacketCache::qrMatch(const packetCache_t::index<HashTag>::type::iterator& iter, const std::string& queryPacket, const DNSName& qname, uint16_t qtype, uint16_t qclass)
+bool RecursorPacketCache::qrMatch(const packetCache_t::index<HashTag>::type::iterator& iter, const std::string& queryPacket, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp)
 {
   // this ignores checking on the EDNS subnet flags!
-  if (qname != iter->d_name || iter->d_type != qtype || iter->d_class != qclass) {
+  // XXX OM tcp check is likely not needed, enforced by index
+  if (qname != iter->d_name || iter->d_type != qtype || iter->d_class != qclass || iter->d_tcp != tcp) {
     return false;
   }
 
@@ -53,11 +54,11 @@ bool RecursorPacketCache::qrMatch(const packetCache_t::index<HashTag>::type::ite
   return queryMatches(iter->d_query, queryPacket, qname, optionsToSkip);
 }
 
-bool RecursorPacketCache::checkResponseMatches(std::pair<packetCache_t::index<HashTag>::type::iterator, packetCache_t::index<HashTag>::type::iterator> range, const std::string& queryPacket, const DNSName& qname, uint16_t qtype, uint16_t qclass, time_t now, std::string* responsePacket, uint32_t* age, vState* valState, OptPBData* pbdata)
+bool RecursorPacketCache::checkResponseMatches(std::pair<packetCache_t::index<HashTag>::type::iterator, packetCache_t::index<HashTag>::type::iterator> range, const std::string& queryPacket, const DNSName& qname, uint16_t qtype, uint16_t qclass, time_t now, std::string* responsePacket, uint32_t* age, vState* valState, OptPBData* pbdata, bool tcp)
 {
   for(auto iter = range.first ; iter != range.second ; ++iter) {
     // the possibility is VERY real that we get hits that are not right - birthday paradox
-    if (!qrMatch(iter, queryPacket, qname, qtype, qclass)) {
+    if (!qrMatch(iter, queryPacket, qname, qtype, qclass, tcp)) {
       continue;
     }
 
@@ -111,37 +112,37 @@ bool RecursorPacketCache::getResponsePacket(unsigned int tag, const std::string&
   DNSName qname;
   uint16_t qtype, qclass;
   vState valState;
-  return getResponsePacket(tag, queryPacket, qname, &qtype, &qclass, now, responsePacket, age, &valState, qhash, nullptr);
+  return getResponsePacket(tag, queryPacket, qname, &qtype, &qclass, now, responsePacket, age, &valState, qhash, nullptr, false);
 }
 
 bool RecursorPacketCache::getResponsePacket(unsigned int tag, const std::string& queryPacket, const DNSName& qname, uint16_t qtype, uint16_t qclass, time_t now,
                                             std::string* responsePacket, uint32_t* age, uint32_t* qhash)
 {
   vState valState;
-  return getResponsePacket(tag, queryPacket, qname, qtype, qclass, now, responsePacket, age, &valState, qhash, nullptr);
+  return getResponsePacket(tag, queryPacket, qname, qtype, qclass, now, responsePacket, age, &valState, qhash, nullptr, false);
 }
 
 bool RecursorPacketCache::getResponsePacket(unsigned int tag, const std::string& queryPacket, const DNSName& qname, uint16_t qtype, uint16_t qclass, time_t now,
-                                            std::string* responsePacket, uint32_t* age, vState* valState, uint32_t* qhash, OptPBData* pbdata)
+                                            std::string* responsePacket, uint32_t* age, vState* valState, uint32_t* qhash, OptPBData* pbdata, bool tcp)
 {
   *qhash = canHashPacket(queryPacket, true);
   const auto& idx = d_packetCache.get<HashTag>();
-  auto range = idx.equal_range(tie(tag,*qhash));
+  auto range = idx.equal_range(tie(tag, *qhash, tcp));
 
   if(range.first == range.second) {
     d_misses++;
     return false;
   }
 
-  return checkResponseMatches(range, queryPacket, qname, qtype, qclass, now, responsePacket, age, valState, pbdata);
+  return checkResponseMatches(range, queryPacket, qname, qtype, qclass, now, responsePacket, age, valState, pbdata, tcp);
 }
 
 bool RecursorPacketCache::getResponsePacket(unsigned int tag, const std::string& queryPacket, DNSName& qname, uint16_t* qtype, uint16_t* qclass, time_t now,
-                                            std::string* responsePacket, uint32_t* age, vState* valState, uint32_t* qhash, OptPBData *pbdata)
+                                            std::string* responsePacket, uint32_t* age, vState* valState, uint32_t* qhash, OptPBData *pbdata, bool tcp)
 {
   *qhash = canHashPacket(queryPacket, true);
   const auto& idx = d_packetCache.get<HashTag>();
-  auto range = idx.equal_range(tie(tag,*qhash));
+  auto range = idx.equal_range(tie(tag, *qhash, tcp));
 
   if(range.first == range.second) {
     d_misses++;
@@ -150,18 +151,19 @@ bool RecursorPacketCache::getResponsePacket(unsigned int tag, const std::string&
 
   qname = DNSName(queryPacket.c_str(), queryPacket.length(), sizeof(dnsheader), false, qtype, qclass, 0);
 
-  return checkResponseMatches(range, queryPacket, qname, *qtype, *qclass, now, responsePacket, age, valState, pbdata);
+  return checkResponseMatches(range, queryPacket, qname, *qtype, *qclass, now, responsePacket, age, valState, pbdata, tcp);
 }
 
 
-void RecursorPacketCache::insertResponsePacket(unsigned int tag, uint32_t qhash, std::string&& query, const DNSName& qname, uint16_t qtype, uint16_t qclass, std::string&& responsePacket, time_t now, uint32_t ttl, const vState& valState, OptPBData&& pbdata)
+void RecursorPacketCache::insertResponsePacket(unsigned int tag, uint32_t qhash, std::string&& query, const DNSName& qname, uint16_t qtype, uint16_t qclass, std::string&& responsePacket, time_t now, uint32_t ttl, const vState& valState, OptPBData&& pbdata, bool tcp)
 {
   auto& idx = d_packetCache.get<HashTag>();
-  auto range = idx.equal_range(tie(tag,qhash));
+  auto range = idx.equal_range(tie(tag, qhash, tcp));
   auto iter = range.first;
 
   for( ; iter != range.second ; ++iter)  {
-    if (iter->d_type != qtype || iter->d_class != qclass || iter->d_name != qname) {
+    // XXX OM tcp check not needed?
+    if (iter->d_type != qtype || iter->d_class != qclass || iter->d_tcp != tcp || iter->d_name != qname ) {
       continue;
     }
 
@@ -180,15 +182,14 @@ void RecursorPacketCache::insertResponsePacket(unsigned int tag, uint32_t qhash,
   }
 
   if(iter == range.second) { // nothing to refresh
-    struct Entry e(qname, std::move(responsePacket), std::move(query));
+    struct Entry e(qname, std::move(responsePacket), std::move(query), tcp);
     e.d_qhash = qhash;
     e.d_type = qtype;
     e.d_class = qclass;
-    e.d_ttd = now+ttl;
+    e.d_ttd = now + ttl;
     e.d_creation = now;
     e.d_tag = tag;
     e.d_vstate = valState;
-    e.d_submitted = false;
     if (pbdata) {
       e.d_pbdata = std::move(*pbdata);
     }
@@ -219,23 +220,24 @@ void RecursorPacketCache::doPruneTo(size_t maxCached)
 uint64_t RecursorPacketCache::doDump(int fd)
 {
   auto fp = std::unique_ptr<FILE, int(*)(FILE*)>(fdopen(dup(fd), "w"), fclose);
-  if(!fp) { // dup probably failed
+  if (!fp) { // dup probably failed
     return 0;
   }
+
   fprintf(fp.get(), "; main packet cache dump from thread follows\n;\n");
-  const auto& sidx=d_packetCache.get<1>();
 
-  uint64_t count=0;
-  time_t now=time(0);
-  for(auto i=sidx.cbegin(); i != sidx.cend(); ++i) {
+  const auto& sidx = d_packetCache.get<SequencedTag>();
+  uint64_t count = 0;
+  time_t now = time(nullptr);
+
+  for (const auto& i : sidx) {
     count++;
     try {
-      fprintf(fp.get(), "%s %" PRId64 " %s  ; tag %d\n", i->d_name.toString().c_str(), static_cast<int64_t>(i->d_ttd - now), DNSRecordContent::NumberToType(i->d_type).c_str(), i->d_tag);
+      fprintf(fp.get(), "%s %" PRId64 " %s  ; tag %d %s\n", i.d_name.toString().c_str(), static_cast<int64_t>(i.d_ttd - now), DNSRecordContent::NumberToType(i.d_type).c_str(), i.d_tag, i.d_tcp ? "tcp" : "udp");
     }
     catch(...) {
-      fprintf(fp.get(), "; error printing '%s'\n", i->d_name.empty() ? "EMPTY" : i->d_name.toString().c_str());
+      fprintf(fp.get(), "; error printing '%s'\n", i.d_name.empty() ? "EMPTY" : i.d_name.toString().c_str());
     }
   }
   return count;
-
 }
index 6cc62244e71d6fcaf0a80fc3234b498aa6dfaee0..fa5d7136497ac50410c373a4cc8cbb7828a21228 100644 (file)
@@ -61,10 +61,10 @@ public:
   RecursorPacketCache();
   bool getResponsePacket(unsigned int tag, const std::string& queryPacket, time_t now, std::string* responsePacket, uint32_t* age, uint32_t* qhash);
   bool getResponsePacket(unsigned int tag, const std::string& queryPacket, const DNSName& qname, uint16_t qtype, uint16_t qclass, time_t now, std::string* responsePacket, uint32_t* age, uint32_t* qhash);
-  bool getResponsePacket(unsigned int tag, const std::string& queryPacket, const DNSName& qname, uint16_t qtype, uint16_t qclass, time_t now, std::string* responsePacket, uint32_t* age, vState* valState, uint32_t* qhash, OptPBData* pbdata);
-  bool getResponsePacket(unsigned int tag, const std::string& queryPacket, DNSName& qname, uint16_t *qtype, uint16_t* qclass, time_t now, std::string* responsePacket, uint32_t* age, vState* valState, uint32_t* qhash, OptPBData* pbdata);
+  bool getResponsePacket(unsigned int tag, const std::string& queryPacket, const DNSName& qname, uint16_t qtype, uint16_t qclass, time_t now, std::string* responsePacket, uint32_t* age, vState* valState, uint32_t* qhash, OptPBData* pbdata, bool tcp);
+  bool getResponsePacket(unsigned int tag, const std::string& queryPacket, DNSName& qname, uint16_t *qtype, uint16_t* qclass, time_t now, std::string* responsePacket, uint32_t* age, vState* valState, uint32_t* qhash, OptPBData* pbdata, bool tcp);
   
-  void insertResponsePacket(unsigned int tag, uint32_t qhash, std::string&& query, const DNSName& qname, uint16_t qtype, uint16_t qclass, std::string&& responsePacket, time_t now, uint32_t ttl, const vState& valState, OptPBData&& pbdata);
+  void insertResponsePacket(unsigned int tag, uint32_t qhash, std::string&& query, const DNSName& qname, uint16_t qtype, uint16_t qclass, std::string&& responsePacket, time_t now, uint32_t ttl, const vState& valState, OptPBData&& pbdata, bool tcp);
   void doPruneTo(size_t maxSize=250000);
   uint64_t doDump(int fd);
   int doWipePacketCache(const DNSName& name, uint16_t qtype=0xffff, bool subtree=false);
@@ -79,22 +79,23 @@ private:
   struct NameTag {};
   struct Entry 
   {
-    Entry(const DNSName& qname, std::string&& packet, std::string&& query): d_name(qname), d_packet(std::move(packet)), d_query(std::move(query))
+    Entry(const DNSName& qname, std::string&& packet, std::string&& query, bool tcp): d_name(qname), d_packet(std::move(packet)), d_query(std::move(query)), d_tcp(tcp)
     {
     }
 
     DNSName d_name;
-    mutable std::string d_packet; // "I know what I am doing"
+    mutable std::string d_packet;
     mutable std::string d_query;
     mutable OptPBData d_pbdata;
     mutable time_t d_ttd;
-    mutable time_t d_creation; // so we can 'age' our packets
+    mutable time_t d_creation;  // so we can 'age' our packets
     uint32_t d_qhash;
     uint32_t d_tag;
     uint16_t d_type;
     uint16_t d_class;
     mutable vState d_vstate;
-    mutable bool d_submitted;   // whether this entry has been queued for refetch
+    mutable bool d_submitted{false}; // whether this entry has been queued for refetch
+    bool d_tcp;                      // whether this entry was created from a TCP query
     inline bool operator<(const struct Entry& rhs) const;
 
     time_t getTTD() const
@@ -110,17 +111,23 @@ private:
   struct SequencedTag{};
   typedef multi_index_container<
     Entry,
-    indexed_by  <
-      hashed_non_unique<tag<HashTag>, composite_key<Entry, member<Entry,uint32_t,&Entry::d_tag>, member<Entry,uint32_t,&Entry::d_qhash> > >,
-      sequenced<tag<SequencedTag>> ,
-      ordered_non_unique<tag<NameTag>, member<Entry,DNSName,&Entry::d_name>, CanonDNSNameCompare >
+    indexed_by <
+      hashed_non_unique<tag<HashTag>,
+                        composite_key<Entry,
+                                      member<Entry, uint32_t, &Entry::d_tag>,
+                                      member<Entry, uint32_t, &Entry::d_qhash>,
+                                      member<Entry, bool, &Entry::d_tcp>
+                                      >
+                        >,
+      sequenced<tag<SequencedTag>>,
+      ordered_non_unique<tag<NameTag>, member<Entry,DNSName, &Entry::d_name>, CanonDNSNameCompare>
       >
-  > packetCache_t;
-  
+    > packetCache_t;
+
   packetCache_t d_packetCache;
 
-  static bool qrMatch(const packetCache_t::index<HashTag>::type::iterator& iter, const std::string& queryPacket, const DNSName& qname, uint16_t qtype, uint16_t qclass);
-  bool checkResponseMatches(std::pair<packetCache_t::index<HashTag>::type::iterator, packetCache_t::index<HashTag>::type::iterator> range, const std::string& queryPacket, const DNSName& qname, uint16_t qtype, uint16_t qclass, time_t now, std::string* responsePacket, uint32_t* age, vState* valState, OptPBData* pbdata);
+  static bool qrMatch(const packetCache_t::index<HashTag>::type::iterator& iter, const std::string& queryPacket, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp);
+  bool checkResponseMatches(std::pair<packetCache_t::index<HashTag>::type::iterator, packetCache_t::index<HashTag>::type::iterator> range, const std::string& queryPacket, const DNSName& qname, uint16_t qtype, uint16_t qclass, time_t now, std::string* responsePacket, uint32_t* age, vState* valState, OptPBData* pbdata, bool tcp);
 
 public:
   void preRemoval(const Entry& entry)
index 7285ae07e100af8ad3f267fefbd1343ef206528d..95107b5d61b499e2e29a30099e5df6374ac66ddd 100644 (file)
@@ -45,16 +45,16 @@ BOOST_AUTO_TEST_CASE(test_recPacketCacheSimple) {
   pw.commit();
   string rpacket((const char*)&packet[0], packet.size());
 
-  rpc.insertResponsePacket(tag, qhash, string(qpacket), qname, QType::A, QClass::IN, string(rpacket), time(0), ttd, vState::Indeterminate, boost::none);
+  rpc.insertResponsePacket(tag, qhash, string(qpacket), qname, QType::A, QClass::IN, string(rpacket), time(0), ttd, vState::Indeterminate, boost::none, false);
   BOOST_CHECK_EQUAL(rpc.size(), 1U);
   rpc.doPruneTo(0);
   BOOST_CHECK_EQUAL(rpc.size(), 0U);
-  rpc.insertResponsePacket(tag, qhash, string(qpacket), qname, QType::A, QClass::IN, string(rpacket), time(0), ttd, vState::Indeterminate, boost::none);
+  rpc.insertResponsePacket(tag, qhash, string(qpacket), qname, QType::A, QClass::IN, string(rpacket), time(0), ttd, vState::Indeterminate, boost::none, false);
   BOOST_CHECK_EQUAL(rpc.size(), 1U);
   rpc.doWipePacketCache(qname);
   BOOST_CHECK_EQUAL(rpc.size(), 0U);
 
-  rpc.insertResponsePacket(tag, qhash, string(qpacket), qname, QType::A, QClass::IN, string(rpacket), time(0), ttd, vState::Indeterminate, boost::none);
+  rpc.insertResponsePacket(tag, qhash, string(qpacket), qname, QType::A, QClass::IN, string(rpacket), time(0), ttd, vState::Indeterminate, boost::none, false);
   BOOST_CHECK_EQUAL(rpc.size(), 1U);
   uint32_t qhash2 = 0;
   bool found = rpc.getResponsePacket(tag, qpacket, time(nullptr), &fpacket, &age, &qhash2);
@@ -217,11 +217,11 @@ BOOST_AUTO_TEST_CASE(test_recPacketCache_Tags) {
   BOOST_CHECK(r1packet != r2packet);
 
   /* inserting a response for tag1 */
-  rpc.insertResponsePacket(tag1, qhash, string(qpacket), qname, QType::A, QClass::IN, string(r1packet), time(0), ttd, vState::Indeterminate, boost::none);
+  rpc.insertResponsePacket(tag1, qhash, string(qpacket), qname, QType::A, QClass::IN, string(r1packet), time(0), ttd, vState::Indeterminate, boost::none, false);
   BOOST_CHECK_EQUAL(rpc.size(), 1U);
 
   /* inserting a different response for tag2, should not override the first one */
-  rpc.insertResponsePacket(tag2, qhash, string(qpacket), qname, QType::A, QClass::IN, string(r2packet), time(0), ttd, vState::Indeterminate, boost::none);
+  rpc.insertResponsePacket(tag2, qhash, string(qpacket), qname, QType::A, QClass::IN, string(r2packet), time(0), ttd, vState::Indeterminate, boost::none, false);
   BOOST_CHECK_EQUAL(rpc.size(), 2U);
 
   /* remove all responses from the cache */
@@ -229,10 +229,10 @@ BOOST_AUTO_TEST_CASE(test_recPacketCache_Tags) {
   BOOST_CHECK_EQUAL(rpc.size(), 0U);
 
   /* reinsert both */
-  rpc.insertResponsePacket(tag1, qhash, string(qpacket), qname, QType::A, QClass::IN, string(r1packet), time(0), ttd, vState::Indeterminate, boost::none);
+  rpc.insertResponsePacket(tag1, qhash, string(qpacket), qname, QType::A, QClass::IN, string(r1packet), time(0), ttd, vState::Indeterminate, boost::none, false);
   BOOST_CHECK_EQUAL(rpc.size(), 1U);
 
-  rpc.insertResponsePacket(tag2, qhash, string(qpacket), qname, QType::A, QClass::IN, string(r2packet), time(0), ttd, vState::Indeterminate, boost::none);
+  rpc.insertResponsePacket(tag2, qhash, string(qpacket), qname, QType::A, QClass::IN, string(r2packet), time(0), ttd, vState::Indeterminate, boost::none, false);
   BOOST_CHECK_EQUAL(rpc.size(), 2U);
 
   /* remove the responses by qname, should remove both */
@@ -240,7 +240,7 @@ BOOST_AUTO_TEST_CASE(test_recPacketCache_Tags) {
   BOOST_CHECK_EQUAL(rpc.size(), 0U);
 
   /* insert the response for tag1 */
-  rpc.insertResponsePacket(tag1, qhash, string(qpacket), qname, QType::A, QClass::IN, string(r1packet), time(0), ttd, vState::Indeterminate, boost::none);
+  rpc.insertResponsePacket(tag1, qhash, string(qpacket), qname, QType::A, QClass::IN, string(r1packet), time(0), ttd, vState::Indeterminate, boost::none, false);
   BOOST_CHECK_EQUAL(rpc.size(), 1U);
 
   /* we can retrieve it */
@@ -259,7 +259,7 @@ BOOST_AUTO_TEST_CASE(test_recPacketCache_Tags) {
   BOOST_CHECK_EQUAL(temphash, qhash);
 
   /* adding a response for the second tag */
-  rpc.insertResponsePacket(tag2, qhash, string(qpacket), qname, QType::A, QClass::IN, string(r2packet), time(0), ttd, vState::Indeterminate, boost::none);
+  rpc.insertResponsePacket(tag2, qhash, string(qpacket), qname, QType::A, QClass::IN, string(r2packet), time(0), ttd, vState::Indeterminate, boost::none, false);
   BOOST_CHECK_EQUAL(rpc.size(), 2U);
 
   /* We still get the correct response for the first tag */