]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Add and test support for views in packet cache.
authorMiod Vallat <miod.vallat@powerdns.com>
Thu, 24 Apr 2025 10:19:58 +0000 (12:19 +0200)
committerMiod Vallat <miod.vallat@powerdns.com>
Mon, 26 May 2025 11:49:12 +0000 (13:49 +0200)
The packet cache data buckets are now set up in an unordered map
addressed by the view name.

Doing this also makes sure that, if the network configuration of a view
changes, the cache contents are still valid as long as there is no
change in the zones found in that view.

12 files changed:
pdns/auth-main.cc
pdns/auth-packetcache.cc
pdns/auth-packetcache.hh
pdns/auth-primarycommunicator.cc
pdns/auth-zonecache.cc
pdns/dnspacket.hh
pdns/packethandler.cc
pdns/tcpreceiver.cc
pdns/test-iputils_hh.cc
pdns/test-packetcache_cc.cc
pdns/ueberbackend.cc
pdns/ws-auth.cc

index adaeb3cc8e1686c017030e5fae081add701b9e3a..cfe621282e11f6860aee547631130393619b5e99 100644 (file)
@@ -628,7 +628,12 @@ try {
 
       if (PC.enabled() && (question.d.opcode != Opcode::Notify && question.d.opcode != Opcode::Update) && question.couldBeCached()) {
         start = diff;
-        bool haveSomething = PC.get(question, cached); // does the PacketCache recognize this question?
+        std::string view{};
+        if (g_views) {
+          Netmask netmask(accountremote);
+          view = g_zoneCache.getViewFromNetwork(&netmask);
+        }
+        bool haveSomething = PC.get(question, cached, view); // does the PacketCache recognize this question?
         if (haveSomething) {
           if (logDNSQueries)
             g_log << ": packetcache HIT" << endl;
index 0a4a835c09f1cf1fb99468360b977c2f3b43e8ef..fd1605c8e98c239946a64f495d214eb6ec0f1624 100644 (file)
@@ -30,7 +30,7 @@ extern StatBag S;
 
 const unsigned int AuthPacketCache::s_mincleaninterval, AuthPacketCache::s_maxcleaninterval;
 
-AuthPacketCache::AuthPacketCache(size_t mapsCount): d_maps(mapsCount), d_lastclean(time(nullptr))
+AuthPacketCache::AuthPacketCache(size_t mapsCount): d_mapscount(mapsCount), d_lastclean(time(nullptr))
 {
   S.declare("packetcache-hit", "Number of hits on the packet cache");
   S.declare("packetcache-miss", "Number of misses on the packet cache");
@@ -41,6 +41,26 @@ AuthPacketCache::AuthPacketCache(size_t mapsCount): d_maps(mapsCount), d_lastcle
   d_statnumhit=S.getPointer("packetcache-hit");
   d_statnummiss=S.getPointer("packetcache-miss");
   d_statnumentries=S.getPointer("packetcache-size");
+
+  // Create the MapCombo for the default view
+  std::string defaultview{};
+  createViewMap(defaultview);
+}
+
+// Create the vector<MapCombo> for the given view.
+// Assumes there is no existing data for the view. Callers are expected to
+// know what they are doing.
+std::unordered_map<std::string, std::unique_ptr<vector<AuthPacketCache::MapCombo>>>::iterator AuthPacketCache::createViewMap(const std::string& view)
+{
+  auto iter = d_cache.emplace(view, std::make_unique<vector<MapCombo>>(d_mapscount));
+  auto retval = iter.first;
+  auto* map = retval->second.get();
+  // Note that this reserves more than intended, especially if multiple views
+  // are used.
+  for (auto& shard : *map) {
+    shard.reserve(d_maxEntries / map->size());
+  }
+  return retval;
 }
 
 void AuthPacketCache::MapCombo::reserve(size_t numberOfEntries)
@@ -50,7 +70,7 @@ void AuthPacketCache::MapCombo::reserve(size_t numberOfEntries)
 #endif /* BOOST_VERSION >= 105600 */
 }
 
-bool AuthPacketCache::get(DNSPacket& p, DNSPacket& cached)
+bool AuthPacketCache::get(DNSPacket& pkt, DNSPacket& cached, const std::string& view)
 {
   if(!d_ttl) {
     return false;
@@ -59,21 +79,27 @@ bool AuthPacketCache::get(DNSPacket& p, DNSPacket& cached)
   cleanupIfNeeded();
 
   static const std::unordered_set<uint16_t> optionsToSkip{ EDNSOptionCode::COOKIE};
-  uint32_t hash = canHashPacket(p.getString(), /* don't skip ECS */optionsToSkip);
-  p.setHash(hash);
+  uint32_t hash = canHashPacket(pkt.getString(), /* don't skip ECS */optionsToSkip);
+  pkt.setHash(hash);
 
   string value;
   bool haveSomething;
   time_t now = time(nullptr);
-  auto& mc = getMap(p.qdomain);
+  auto iter = d_cache.find(view);
+  if (iter == d_cache.end()) {
+    // No data for this view yet.
+    (*d_statnummiss)++;
+    return false;
+  }
+  auto& mapcombo = getMap(iter->second, pkt.qdomain);
   {
-    auto map = mc.d_map.try_read_lock();
+    auto map = mapcombo.d_map.try_read_lock();
     if (!map.owns_lock()) {
       S.inc("deferred-packetcache-lookup");
       return false;
     }
 
-    haveSomething = getEntryLocked(*map, p.getString(), hash, p.qdomain, p.qtype.getCode(), p.d_tcp, now, value);
+    haveSomething = AuthPacketCache::getEntryLocked(*map, pkt.getString(), hash, pkt.qdomain, pkt.qtype.getCode(), pkt.d_tcp, now, value);
   }
 
   if (!haveSomething) {
@@ -86,9 +112,9 @@ bool AuthPacketCache::get(DNSPacket& p, DNSPacket& cached)
   }
 
   (*d_statnumhit)++;
-  cached.spoofQuestion(p); // for correct case
-  cached.qdomain = p.qdomain;
-  cached.qtype = p.qtype;
+  cached.spoofQuestion(pkt); // for correct case
+  cached.qdomain = pkt.qdomain;
+  cached.qtype = pkt.qtype;
 
   return true;
 }
@@ -99,7 +125,7 @@ bool AuthPacketCache::entryMatches(cmap_t::index<HashTag>::type::iterator& iter,
   return iter->tcp == tcp && iter->qtype == qtype && iter->qname == qname && queryMatches(iter->query, query, qname, skippedEDNSTypes);
 }
 
-void AuthPacketCache::insert(DNSPacket& q, DNSPacket& r, unsigned int maxTTL)
+void AuthPacketCache::insert(DNSPacket& query, DNSPacket& response, unsigned int maxTTL, const std::string& view)
 {
   if(!d_ttl) {
     return;
@@ -107,31 +133,37 @@ void AuthPacketCache::insert(DNSPacket& q, DNSPacket& r, unsigned int maxTTL)
 
   cleanupIfNeeded();
 
-  if (ntohs(q.d.qdcount) != 1) {
+  if (ntohs(query.d.qdcount) != 1) {
     return; // do not try to cache packets with multiple questions
   }
 
-  if (q.qclass != QClass::IN) // we only cache the INternet
+  if (query.qclass != QClass::IN) { // we only cache the INternet
     return;
+  }
 
   uint32_t ourttl = std::min(d_ttl, maxTTL);
   if (ourttl == 0) {
     return;
-  }  
+  }
 
-  uint32_t hash = q.getHash();
+  uint32_t hash = query.getHash();
   time_t now = time(nullptr);
   CacheEntry entry;
   entry.hash = hash;
   entry.created = now;
   entry.ttd = now + ourttl;
-  entry.qname = q.qdomain;
-  entry.qtype = q.qtype.getCode();
-  entry.value = r.getString();
-  entry.tcp = r.d_tcp;
-  entry.query = q.getString();
-  
-  auto& mc = getMap(entry.qname);
+  entry.qname = query.qdomain;
+  entry.qtype = query.qtype.getCode();
+  entry.value = response.getString();
+  entry.tcp = response.d_tcp;
+  entry.query = query.getString();
+
+  auto iter = d_cache.find(view);
+  if (iter == d_cache.end()) {
+    // No data for this view yet, create it.
+    iter = createViewMap(view);
+  }
+  auto& mc = getMap(iter->second, entry.qname); // NOLINT(readability-identifier-length)
   {
     auto map = mc.d_map.try_write_lock();
     if (!map.owns_lock()) {
@@ -141,17 +173,17 @@ void AuthPacketCache::insert(DNSPacket& q, DNSPacket& r, unsigned int maxTTL)
 
     auto& idx = map->get<HashTag>();
     auto range = idx.equal_range(hash);
-    auto iter = range.first;
+    auto iter2 = range.first;
 
-    for( ; iter != range.second ; ++iter)  {
-      if (!entryMatches(iter, entry.query, entry.qname, entry.qtype, entry.tcp)) {
+    for( ; iter2 != range.second ; ++iter2)  {
+      if (!entryMatches(iter2, entry.query, entry.qname, entry.qtype, entry.tcp)) {
         continue;
       }
 
-      moveCacheItemToBack<SequencedTag>(*map, iter);
-      iter->value = entry.value;
-      iter->ttd = now + ourttl;
-      iter->created = now;
+      moveCacheItemToBack<SequencedTag>(*map, iter2);
+      iter2->value = entry.value;
+      iter2->ttd = now + ourttl;
+      iter2->created = now;
       return;
     }
 
@@ -171,7 +203,7 @@ void AuthPacketCache::insert(DNSPacket& q, DNSPacket& r, unsigned int maxTTL)
 
 bool AuthPacketCache::getEntryLocked(const cmap_t& map, const std::string& query, uint32_t hash, const DNSName &qname, uint16_t qtype, bool tcp, time_t now, string& value)
 {
-  auto& idx = map.get<HashTag>();
+  const auto& idx = map.get<HashTag>();
   auto range = idx.equal_range(hash);
 
   for(auto iter = range.first; iter != range.second ; ++iter)  {
@@ -182,7 +214,6 @@ bool AuthPacketCache::getEntryLocked(const cmap_t& map, const std::string& query
     if (!entryMatches(iter, query, qname, qtype, tcp)) {
       continue;
     }
-
     value = iter->value;
     return true;
   }
@@ -199,13 +230,36 @@ uint64_t AuthPacketCache::purge()
 
   d_statnumentries->store(0);
 
-  return purgeLockedCollectionsVector(d_maps);
+  uint64_t delcount = 0;
+  for (auto& iter : d_cache) {
+    auto* map = iter.second.get();
+    delcount += purgeLockedCollectionsVector(*map);
+  }
+  return delcount;
 }
 
 uint64_t AuthPacketCache::purgeExact(const DNSName& qname)
 {
-  auto& mc = getMap(qname);
-  uint64_t delcount = purgeExactLockedCollection<NameTag>(mc, qname);
+  uint64_t delcount = 0;
+
+  for (auto& iter : d_cache) {
+    auto& mc = getMap(iter.second, qname); // NOLINT(readability-identifier-length)
+    delcount += purgeExactLockedCollection<NameTag>(mc, qname);
+  }
+
+  *d_statnumentries -= delcount;
+
+  return delcount;
+}
+
+uint64_t AuthPacketCache::purgeExact(const std::string& view, const DNSName& qname)
+{
+  uint64_t delcount = 0;
+
+  if (auto iter = d_cache.find(view); iter != d_cache.end()) {
+    auto& mc = getMap(iter->second, qname); // NOLINT(readability-identifier-length)
+    delcount += purgeExactLockedCollection<NameTag>(mc, qname);
+  }
 
   *d_statnumentries -= delcount;
 
@@ -222,7 +276,10 @@ uint64_t AuthPacketCache::purge(const string &match)
   uint64_t delcount = 0;
 
   if(boost::ends_with(match, "$")) {
-    delcount = purgeLockedCollectionsVector<NameTag>(d_maps, match);
+    for (auto& iter : d_cache) {
+      auto* map = iter.second.get();
+      delcount += purgeLockedCollectionsVector<NameTag>(*map, match);
+    }
     *d_statnumentries -= delcount;
   }
   else {
@@ -231,10 +288,14 @@ uint64_t AuthPacketCache::purge(const string &match)
 
   return delcount;
 }
-                          
+
 void AuthPacketCache::cleanup()
 {
-  uint64_t totErased = pruneLockedCollectionsVector<SequencedTag>(d_maps);
+  uint64_t totErased = 0;
+  for (auto& iter : d_cache) {
+    auto* map = iter.second.get();
+    totErased += pruneLockedCollectionsVector<SequencedTag>(*map);
+  }
   *d_statnumentries -= totErased;
 
   DLOG(g_log<<"Done with cache clean, cacheSize: "<<(*d_statnumentries)<<", totErased"<<totErased<<endl);
index 0b22e6282556d15f63b10ccf4cff23b25b824e07..8157ae26d9f2deaf9272481a4cddfc9fd0388c5d 100644 (file)
@@ -22,6 +22,7 @@
 #pragma once
 #include <string>
 #include <map>
+#include <unordered_map>
 #include "dns.hh"
 #include <boost/version.hpp>
 #include "namespaces.hh"
@@ -40,7 +41,10 @@ using namespace ::boost::multi_index;
 /** This class performs 'whole packet caching'. Feed it a question packet and it will
     try to find an answer. If you have an answer, insert it to have it cached for later use. 
     Take care not to replace existing cache entries. While this works, it is wasteful. Only
-    insert packets that where not found by get()
+    insert packets that were not found by get()
+
+    Caches are indexed by views. When views are not used, all the data in the
+    cache is associated to the empty string "" default view.
 
     Locking! 
 
@@ -53,29 +57,34 @@ class AuthPacketCache : public PacketCache
 public:
   AuthPacketCache(size_t mapsCount=1024);
 
-  void insert(DNSPacket& q, DNSPacket& r, uint32_t maxTTL);  //!< We copy the contents of *p into our cache. Do not needlessly call this to insert questions already in the cache as it wastes resources
+  void insert(DNSPacket& query, DNSPacket& response, uint32_t maxTTL, const std::string& view);  //!< We copy the contents of *p into our cache. Do not needlessly call this to insert questions already in the cache as it wastes resources
 
-  bool get(DNSPacket& p, DNSPacket& q); //!< You need to spoof in the right ID with the DNSPacket.spoofID() method.
+  bool get(DNSPacket& pkt, DNSPacket& cached, const std::string& view = ""); //!< You need to spoof in the right ID with the DNSPacket.spoofID() method.
 
   void cleanup(); //!< force the cache to preen itself from expired packets
   uint64_t purge();
   uint64_t purge(const std::string& match); // could be $ terminated. Is not a dnsname!
   uint64_t purgeExact(const DNSName& qname); // no wildcard matching here
+  uint64_t purgeExact(const std::string& view, const DNSName& qname); // same as above, but in the given view
 
   uint64_t size() const { return *d_statnumentries; };
 
   void setMaxEntries(uint64_t maxEntries) 
   {
     d_maxEntries = maxEntries;
-    for (auto& shard : d_maps) {
-      shard.reserve(maxEntries / d_maps.size());
+    for (auto& iter : d_cache) {
+      auto* map = iter.second.get();
+      
+      for (auto& shard : *map) {
+        shard.reserve(maxEntries / map->size());
+      }
     }
   }
   void setTTL(uint32_t ttl)
   {
     d_ttl = ttl;
   }
-  bool enabled()
+  bool enabled() const
   {
     return (d_ttl > 0);
   }
@@ -120,14 +129,15 @@ private:
     SharedLockGuarded<cmap_t> d_map;
   };
 
-  vector<MapCombo> d_maps;
-  MapCombo& getMap(const DNSName& name)
+  std::unordered_map<std::string, std::unique_ptr<vector<MapCombo>>> d_cache;
+  static MapCombo& getMap(std::unique_ptr<vector<MapCombo>>& map, const DNSName& name)
   {
-    return d_maps[name.hash() % d_maps.size()];
+    return (*map)[name.hash() % map->size()];
   }
 
+  std::unordered_map<std::string, std::unique_ptr<vector<MapCombo>>>::iterator createViewMap(const std::string& view);
   static bool entryMatches(cmap_t::index<HashTag>::type::iterator& iter, const std::string& query, const DNSName& qname, uint16_t qtype, bool tcp);
-  bool getEntryLocked(const cmap_t& map, const std::string& query, uint32_t hash, const DNSName &qname, uint16_t qtype, bool tcp, time_t now, string& entry);
+  static bool getEntryLocked(const cmap_t& map, const std::string& query, uint32_t hash, const DNSName &qname, uint16_t qtype, bool tcp, time_t now, string& value);
   void cleanupIfNeeded();
 
   AtomicCounter d_ops{0};
@@ -136,6 +146,7 @@ private:
   AtomicCounter *d_statnumentries;
 
   uint64_t d_maxEntries{0};
+  size_t d_mapscount;
   time_t d_lastclean; // doesn't need to be atomic
   unsigned long d_nextclean{4096};
   unsigned int d_cleaninterval{4096};
index 6e35c3a0076487659fb67d0bcd52a767f39edbe4..b45e6e06dbefb0ba55ec8fa0264a78caab4f43dd 100644 (file)
@@ -202,6 +202,8 @@ void CommunicatorClass::primaryUpdateCheck(PacketHandler* P)
   }
 
   for (auto& di : cmdomains) {
+    // VIEWS TODO: if this zone has a variant, try to figure out which
+    // views contain it, and purge these views only.
     purgeAuthCachesExact(di.zone.operator const DNSName&());
     g_zoneCache.add(di.zone, di.id);
     queueNotifyDomain(di, B);
index c20fc7ddf01734fb5fa6d7f8561f2e584126f236..ff4c0948d96e632e76fe947b4b22b1dcbd1dd00e 100644 (file)
@@ -88,15 +88,6 @@ std::string AuthZoneCache::getViewFromNetwork(Netmask* net)
     // this handles the "empty" case, but might hide other errors
   }
 
-  // If this network doesn't match a view, then we want to clear the netmask
-  // information, as our caller might submit it to the packet cache and there
-  // is no reason to narrow caching for views-agnostic queries.
-  // TODO: no longer needed once packet cache indexes on views rather than
-  // netmasks.
-  if (view.empty()) {
-    *net = Netmask();
-  }
-
   return view;
 }
 
index 13146fde4b5c2dc0d35fa3b522bbcf86a9d19e34..9e16b432c70b3350c3dfd0bbaf740fcfb4fe2db7 100644 (file)
@@ -177,6 +177,9 @@ public:
   void cleanupGSS(int rcode);
 #endif
 
+  Netmask d_span; // network matching this packet, when views are used
+  std::string d_view; // view matching this packet, when views are used
+
 private:
   void pasteQ(const char *question, int length); //!< set the question of this packet, useful for crafting replies
 
index 4c76288e25a805aed631219bf2088da41f1e7d04..a35003bda33a30f08bf308c74a0c3973b4011197 100644 (file)
@@ -1937,7 +1937,7 @@ std::unique_ptr<DNSPacket> PacketHandler::opcodeQuery(DNSPacket& pkt, bool noCac
     }
 
     if (PC.enabled() && !state.noCache && pkt.couldBeCached()) {
-      PC.insert(pkt, *state.r, state.r->getMinTTL()); // in the packet cache
+      PC.insert(pkt, *state.r, state.r->getMinTTL(), pkt.d_view); // in the packet cache
     }
   }
 
index 454ef41e6af75a5cf70d8be1ab0c74759a8266f4..7477665b585968b2922283b1b4e5c0b039fbbd9f 100644 (file)
@@ -388,18 +388,26 @@ void TCPNameserver::doConnection(int fd)
         "', do = " <<packet->d_dnssecOk <<", bufsize = "<< packet->getMaxReplyLen();
       }
 
-      if(PC.enabled()) {
-        if(packet->couldBeCached() && PC.get(*packet, *cached)) { // short circuit - does the PacketCache recognize this question?
-          if(logDNSQueries)
-            g_log<<": packetcache HIT"<<endl;
-          cached->setRemote(&packet->d_remote);
-          cached->d_inner_remote = packet->d_inner_remote;
-          cached->d.id=packet->d.id;
-          cached->d.rd=packet->d.rd; // copy in recursion desired bit
-          cached->commitD(); // commit d to the packet                        inlined
-
-          sendPacket(cached, fd); // presigned, don't do it again
-          continue;
+      if (PC.enabled()) {
+        if (packet->couldBeCached()) {
+          std::string view{};
+          if (g_views) {
+            Netmask netmask(packet->d_remote);
+            view = g_zoneCache.getViewFromNetwork(&netmask);
+          }
+          if (PC.get(*packet, *cached, view)) { // short circuit - does the PacketCache recognize this question?
+            if(logDNSQueries) {
+              g_log<<": packetcache HIT"<<endl;
+           }
+            cached->setRemote(&packet->d_remote);
+            cached->d_inner_remote = packet->d_inner_remote;
+            cached->d.id=packet->d.id;
+            cached->d.rd=packet->d.rd; // copy in recursion desired bit
+            cached->commitD(); // commit d to the packet                        inlined
+
+            sendPacket(cached, fd); // presigned, don't do it again
+            continue;
+          }
         }
         if(logDNSQueries)
             g_log<<": packetcache MISS"<<endl;
index b55259ba9ea1041a0533ae5943c4da18744ad13e..5d48c74134997180f16e09a8e8b599370fa72e3e 100644 (file)
@@ -299,8 +299,18 @@ BOOST_AUTO_TEST_CASE(test_Netmask) {
   BOOST_CHECK(nm1921 < nm1922);
   BOOST_CHECK(nm1922 > nm1921);
 
+  Netmask outer("20.25.0.0/16");
+  Netmask inner("20.25.4.0/24");
+  Netmask disjoint("20.24.0.0/16");
+  BOOST_CHECK(outer.match(inner.getNetwork()));
+  BOOST_CHECK(!inner.match(outer.getNetwork()));
+  BOOST_CHECK(!outer.match(disjoint.getNetwork()));
+  BOOST_CHECK(!inner.match(disjoint.getNetwork()));
+  BOOST_CHECK(!disjoint.match(inner.getNetwork()));
+  BOOST_CHECK(!disjoint.match(outer.getNetwork()));
+
   /* An empty Netmask should be larger than
-     every others. */
+     every other. */
   Netmask empty = Netmask();
   Netmask full("255.255.255.255/32");
   BOOST_CHECK(empty > all);
index 4cba7e8b8474fea7db7b3af89af1a10daf159352..021479e68aa25ff3b69bb465a76d8ca0223cc2f6 100644 (file)
@@ -13,6 +13,9 @@
 #include "statbag.hh"
 #include "auth-packetcache.hh"
 #include "auth-querycache.hh"
+#ifdef PDNS_AUTH
+#include "auth-zonecache.hh"
+#endif
 #include "arguments.hh"
 #include <utility>
 #include <thread>
@@ -41,7 +44,7 @@ BOOST_AUTO_TEST_CASE(test_AuthQueryCacheSimple) {
 
       QC.insert(a, QType(QType::A), vector<DNSZoneRecord>(records), 3600, 1);
       if(!QC.purge(a.toString()))
-       BOOST_FAIL("Could not remove entry we just added to the query cache!");
+        BOOST_FAIL("Could not remove entry we just added to the query cache!");
       QC.insert(a, QType(QType::A), vector<DNSZoneRecord>(records), 3600, 1);
     }
 
@@ -60,7 +63,7 @@ BOOST_AUTO_TEST_CASE(test_AuthQueryCacheSimple) {
     int64_t expected=counter-delcounter;
     for(; delcounter < counter; ++delcounter) {
       if(QC.getEntry(DNSName("hello ")+DNSName(std::to_string(delcounter)), QType(QType::A), entry, 1)) {
-       matches++;
+        matches++;
       }
     }
     BOOST_CHECK_EQUAL(matches, expected);
@@ -171,7 +174,7 @@ try
     q.setHash(g_PC->canHashPacket(q.getString()));
 
     const unsigned int maxTTL = 3600;
-    g_PC->insert(q, r, maxTTL);
+    g_PC->insert(q, r, maxTTL, "");
   }
 }
  catch(PDNSException& e) {
@@ -393,7 +396,7 @@ BOOST_AUTO_TEST_CASE(test_AuthPacketCache) {
     /* this call is required so the correct hash is set into q->d_hash */
     BOOST_CHECK_EQUAL(PC.get(q, r2), false);
 
-    PC.insert(q, r, 3600);
+    PC.insert(q, r, 3600, "");
     BOOST_CHECK_EQUAL(PC.size(), 1U);
 
     BOOST_CHECK_EQUAL(PC.get(q, r2), true);
@@ -406,7 +409,7 @@ BOOST_AUTO_TEST_CASE(test_AuthPacketCache) {
     /* with EDNS, should not match */
     BOOST_CHECK_EQUAL(PC.get(ednsQ, r2), false);
     /* inserting the EDNS-enabled one too */
-    PC.insert(ednsQ, r, 3600);
+    PC.insert(ednsQ, r, 3600, "");
     BOOST_CHECK_EQUAL(PC.size(), 2U);
 
     /* different EDNS versions, should not match */
@@ -422,7 +425,7 @@ BOOST_AUTO_TEST_CASE(test_AuthPacketCache) {
 
     /* inserting the version with ECS Client Subnet set,
      it should NOT replace the existing EDNS one. */
-    PC.insert(ecs1, r, 3600);
+    PC.insert(ecs1, r, 3600, "");
     BOOST_CHECK_EQUAL(PC.size(), 3U);
 
     /* different subnet of same size, should NOT match
@@ -437,7 +440,7 @@ BOOST_AUTO_TEST_CASE(test_AuthPacketCache) {
     BOOST_CHECK_EQUAL(PC.get(q, r2), false);
     BOOST_CHECK_EQUAL(PC.size(), 0U);
 
-    PC.insert(q, r, 3600);
+    PC.insert(q, r, 3600, "");
     BOOST_CHECK_EQUAL(PC.size(), 1U);
     BOOST_CHECK_EQUAL(PC.get(q, r2), true);
     BOOST_CHECK_EQUAL(r2.qdomain, r.qdomain);
@@ -445,7 +448,7 @@ BOOST_AUTO_TEST_CASE(test_AuthPacketCache) {
     BOOST_CHECK_EQUAL(PC.get(q, r2), false);
     BOOST_CHECK_EQUAL(PC.size(), 0U);
 
-    PC.insert(q, r, 3600);
+    PC.insert(q, r, 3600, "");
     BOOST_CHECK_EQUAL(PC.size(), 1U);
     BOOST_CHECK_EQUAL(PC.get(q, r2), true);
     BOOST_CHECK_EQUAL(r2.qdomain, r.qdomain);
@@ -453,7 +456,7 @@ BOOST_AUTO_TEST_CASE(test_AuthPacketCache) {
     BOOST_CHECK_EQUAL(PC.get(q, r2), false);
     BOOST_CHECK_EQUAL(PC.size(), 0U);
 
-    PC.insert(q, r, 3600);
+    PC.insert(q, r, 3600, "");
     BOOST_CHECK_EQUAL(PC.size(), 1U);
     BOOST_CHECK_EQUAL(PC.get(q, r2), true);
     BOOST_CHECK_EQUAL(r2.qdomain, r.qdomain);
@@ -461,7 +464,7 @@ BOOST_AUTO_TEST_CASE(test_AuthPacketCache) {
     BOOST_CHECK_EQUAL(PC.get(q, r2), false);
     BOOST_CHECK_EQUAL(PC.size(), 0U);
 
-    PC.insert(q, r, 3600);
+    PC.insert(q, r, 3600, "");
     BOOST_CHECK_EQUAL(PC.size(), 1U);
     BOOST_CHECK_EQUAL(PC.purge("www.powerdns.net"), 0U);
     BOOST_CHECK_EQUAL(PC.get(q, r2), true);
@@ -482,4 +485,230 @@ BOOST_AUTO_TEST_CASE(test_AuthPacketCache) {
   }
 }
 
+static void feedPacketCache(AuthPacketCache& PC, uint32_t bits, const std::string& view) // NOLINT(readability-identifier-length)
+{
+  for (unsigned int counter = 0; counter < 128; ++counter) {
+    std::vector<uint8_t> storage;
+    DNSName qname = DNSName("network" + std::to_string(counter));
+    DNSPacketWriter qwriter(storage, qname, QType::A);
+    DNSPacket query(true);
+    query.parse(reinterpret_cast<char*>(storage.data()), storage.size()); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast): can't static_cast because of sign difference
+    storage.clear();
+    DNSPacketWriter rwriter(storage, qname, QType::A);
+    rwriter.startRecord(qname, QType::A, 3600, QClass::IN, DNSResourceRecord::ANSWER);
+    rwriter.xfrIP(htonl((counter << 24) | bits));
+    rwriter.commit();
+    DNSPacket response(false);
+    response.parse(reinterpret_cast<char*>(storage.data()), storage.size()); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast): can't static_cast because of sign difference
+    // magic copied from threadPCMangler() above
+    query.setHash(PacketCache::canHashPacket(query.getString()));
+    PC.insert(query, response, 2600, view);
+  }
+}
+
+static void slurpPacketCache(AuthPacketCache& PC, const std::string& bits, const std::string& view) // NOLINT(readability-identifier-length)
+{
+  for (unsigned int counter = 0; counter < 128; ++counter) {
+    std::vector<uint8_t> storage;
+    DNSName qname = DNSName("network" + std::to_string(counter));
+    DNSPacketWriter qwriter(storage, qname, QType::A);
+    DNSPacket query(true);
+    query.parse(reinterpret_cast<char*>(storage.data()), storage.size()); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast): can't static_cast because of sign difference
+
+    DNSPacket response(false);
+    bool hit = PC.get(query, response, view);
+    BOOST_CHECK_EQUAL(hit, true);
+    if (!hit) {
+      continue;
+    }
+    BOOST_CHECK_EQUAL(response.qdomain, query.qdomain);
+    const std::string& wiresponse = response.getString();
+    MOADNSParser parser(false, wiresponse.c_str(), wiresponse.size());
+    BOOST_REQUIRE_EQUAL(parser.d_answers.size(), 1U);
+    const auto& record = parser.d_answers.at(0);
+    BOOST_REQUIRE_EQUAL(record.d_type, QType::A);
+    BOOST_REQUIRE_EQUAL(record.d_class, QClass::IN);
+    auto content = getRR<ARecordContent>(record);
+    BOOST_REQUIRE(content);
+    BOOST_REQUIRE_EQUAL(content->getCA().toString(), std::to_string(counter) + bits);
+  }
+}
+
+BOOST_AUTO_TEST_CASE(test_AuthPacketCacheNetmasks) {
+  try {
+    ::arg().setSwitch("no-shuffle","Set this to prevent random shuffling of answers - for regression testing")="off";
+
+    AuthPacketCache PC; // NOLINT(readability-identifier-length) 
+    PC.setMaxEntries(1000000);
+    PC.setTTL(0xc0ffee); // cache works better when programmer is cafeinated and doesn't forget to enable it
+
+    std::string view1{"view1"};
+    std::string view2{"view2"};
+
+    // Set up a few packets with no view.
+    feedPacketCache(PC, 0x00010203, "");
+    BOOST_REQUIRE_EQUAL(PC.size(), 128 * 1);
+
+    // Set up a few packets with a view and different A result.
+    feedPacketCache(PC, 0x00020406, view1);
+    BOOST_REQUIRE_EQUAL(PC.size(), 128 * 2);
+
+    // Set up a few packets with yet another view and yet another different A result.
+    feedPacketCache(PC, 0x00030609, view2);
+    BOOST_REQUIRE_EQUAL(PC.size(), 128 * 3);
+
+    // Now check that we are getting cache hits for all the packets we've added,
+    // with the correct answers
+    slurpPacketCache(PC, ".1.2.3", "");
+    slurpPacketCache(PC, ".2.4.6", view1);
+    slurpPacketCache(PC, ".3.6.9", view2);
+  }
+  catch(PDNSException& e) {
+    cerr<<"Had error in AuthPacketCache: "<<e.reason<<endl;
+    throw;
+  }
+}
+
+#ifdef PDNS_AUTH // [
+// Combined packet cache and zone cache test to exercize views
+
+static DNSPacket buildQuery(const DNSName& qname)
+{
+  std::vector<uint8_t> storage;
+  DNSPacketWriter qwriter(storage, qname, QType::A);
+  DNSPacket query(true);
+  query.parse(reinterpret_cast<char*>(storage.data()), storage.size()); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast): can't static_cast because of sign difference
+  storage.clear();
+  // magic copied from threadPCMangler() above
+  query.setHash(PacketCache::canHashPacket(query.getString()));
+  return query;
+}
+
+static void feedPacketCache2(AuthPacketCache& PC, const std::string& view, uint32_t ipAddress, const DNSName& qname) // NOLINT(readability-identifier-length)
+{
+  DNSPacket query = buildQuery(qname);
+
+  std::vector<uint8_t> storage;
+  DNSPacketWriter rwriter(storage, qname, QType::A);
+  rwriter.startRecord(qname, QType::A, 3600, QClass::IN, DNSResourceRecord::ANSWER);
+  rwriter.xfrIP(htonl(ipAddress));
+  rwriter.commit();
+  DNSPacket response(false);
+  response.parse(reinterpret_cast<char*>(storage.data()), storage.size()); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast): can't static_cast because of sign difference
+
+  PC.insert(query, response, 2600, view);
+}
+
+static bool queryPacketCache2(AuthPacketCache& PC, AuthZoneCache& ZC, ComboAddress from, const DNSName& qname, const Netmask& expectedMask, const std::string& expectedView, const std::string& expectedAddress) // NOLINT(readability-identifier-length)
+{
+  DNSPacket query = buildQuery(qname);
+  DNSPacket response(false);
+
+  Netmask netmask(from);
+  std::string view = ZC.getViewFromNetwork(&netmask);
+  BOOST_REQUIRE(netmask == expectedMask);
+  BOOST_REQUIRE(view == expectedView);
+
+  bool hit = PC.get(query, response, view);
+  if (hit) {
+    BOOST_CHECK_EQUAL(response.qdomain, query.qdomain);
+    const std::string& wiresponse = response.getString();
+    MOADNSParser parser(false, wiresponse.c_str(), wiresponse.size());
+    BOOST_REQUIRE_EQUAL(parser.d_answers.size(), 1U);
+    const auto& record = parser.d_answers.at(0);
+    BOOST_REQUIRE_EQUAL(record.d_type, QType::A);
+    BOOST_REQUIRE_EQUAL(record.d_class, QClass::IN);
+    auto content = getRR<ARecordContent>(record);
+    BOOST_REQUIRE(content);
+    BOOST_REQUIRE_EQUAL(content->getCA().toString(), expectedAddress);
+  }
+  return hit;
+}
+
+BOOST_AUTO_TEST_CASE(test_AuthViews)
+{
+  // Setup Zone Cache
+
+  AuthZoneCache ZC; // NOLINT(readability-identifier-length) 
+  ZC.setRefreshInterval(3600);
+
+  // Declare a few zones
+  ZoneName foo("example.com..foo");
+  ZoneName bar("example.com..bar");
+  ZC.add(foo, static_cast<domainid_t>('F'));
+  ZC.add(bar, static_cast<domainid_t>('B'));
+
+  // Declare a few networks
+  std::string view1{"view1"};
+  std::string view2{"view2"};
+  Netmask outerMask("192.0.2.0/24");
+  Netmask innerMask("192.0.2.0/25");
+  ZC.updateNetwork(outerMask, view1);
+  ZC.updateNetwork(innerMask, view2);
+
+  // Declare a few views
+  ZC.addToView(view1, foo);
+  ZC.addToView(view2, bar);
+
+  // Setup Packet Cache
+
+  AuthPacketCache PC; // NOLINT(readability-identifier-length) 
+  PC.setMaxEntries(1000000);
+  PC.setTTL(0xc0ffee); // cache works better when programmer is cafeinated and doesn't forget to enable it
+
+  // Cache answer for query in view2
+  DNSName qname("example.com");
+  feedPacketCache2(PC, view2, 0x02020202, qname);
+  BOOST_CHECK_EQUAL(PC.size(), 1);
+
+  // Check that requesting from view1 causes a cache miss
+  BOOST_CHECK_EQUAL(queryPacketCache2(PC, ZC, ComboAddress("192.0.2.128"), qname, outerMask, view1, "1.1.1.1"), false);
+
+  // Check that requesting from view2 causes a cache hit
+  BOOST_CHECK_EQUAL(queryPacketCache2(PC, ZC, ComboAddress("192.0.2.1"), qname, innerMask, view2, "2.2.2.2"), true);
+
+  // Cache answer for query in view1
+  feedPacketCache2(PC, view1, 0x01010101, qname);
+  BOOST_CHECK_EQUAL(PC.size(), 2);
+
+  // Check that requesting from view1 causes a cache hit with the right data
+  BOOST_CHECK_EQUAL(queryPacketCache2(PC, ZC, ComboAddress("192.0.2.128"), qname, outerMask, view1, "1.1.1.1"), true);
+
+  // Check that requesting from view2 causes a cache hit with the right data
+  BOOST_CHECK_EQUAL(queryPacketCache2(PC, ZC, ComboAddress("192.0.2.1"), qname, innerMask, view2, "2.2.2.2"), true);
+
+  // Purge view2
+  BOOST_CHECK_EQUAL(PC.purgeExact(view2, qname), 1);
+  BOOST_CHECK_EQUAL(PC.size(), 1);
+
+  // Check that requesting from view2 causes a cache miss
+  BOOST_CHECK_EQUAL(queryPacketCache2(PC, ZC, ComboAddress("192.0.2.1"), qname, innerMask, view2, "2.2.2.2"), false);
+
+  // Check that requesting from view1 causes a cache hit with the right data
+  BOOST_CHECK_EQUAL(queryPacketCache2(PC, ZC, ComboAddress("192.0.2.128"), qname, outerMask, view1, "1.1.1.1"), true);
+
+  // Purge view1
+  BOOST_CHECK_EQUAL(PC.purgeExact(view1, qname), 1);
+  BOOST_CHECK_EQUAL(PC.size(), 0);
+
+  // Check that requesting from view1 causes a cache miss
+  BOOST_CHECK_EQUAL(queryPacketCache2(PC, ZC, ComboAddress("192.0.2.128"), qname, outerMask, view1, "1.1.1.1"), false);
+
+  // Cache answers for view1 and view2 again
+  feedPacketCache2(PC, view1, 0x01010101, qname);
+  feedPacketCache2(PC, view2, 0x02020202, qname);
+  BOOST_CHECK_EQUAL(PC.size(), 2);
+
+  // Purge all views
+  BOOST_CHECK_EQUAL(PC.purgeExact(qname), 2);
+  BOOST_CHECK_EQUAL(PC.size(), 0);
+
+  // Check that requesting from view1 causes a cache miss
+  BOOST_CHECK_EQUAL(queryPacketCache2(PC, ZC, ComboAddress("192.0.2.128"), qname, outerMask, view1, "1.1.1.1"), false);
+
+  // Check that requesting from view2 causes a cache miss
+  BOOST_CHECK_EQUAL(queryPacketCache2(PC, ZC, ComboAddress("192.0.2.1"), qname, innerMask, view2, "2.2.2.2"), false);
+}
+#endif // ] PDNS_AUTH
+
 BOOST_AUTO_TEST_SUITE_END()
index b62dc57179cc05f60c9cc3cf127c660449fa22bb..c42a874a2ed4168b18192066795ed218a939a796 100644 (file)
@@ -540,8 +540,9 @@ bool UeberBackend::getAuth(const ZoneName& target, const QType& qtype, SOAData*
   if (g_zoneCache.isEnabled()) {
     Netmask _remote(remote);
     view = g_zoneCache.getViewFromNetwork(&_remote);
-    // Remember the view netmask, if applicable, for ECS responses.
+    // Remember the view and its netmask, if applicable, for ECS responses.
     if (!view.empty() && pkt_p != nullptr) {
+      pkt_p->d_view = view;
       pkt_p->d_span = _remote;
     }
   }
index 64a45758ee97892657434f3ef25f7e7ce072dba9..f8557b12f544f7491d732df884f00b68d364399a 100644 (file)
@@ -48,6 +48,7 @@
 #include "zoneparser-tng.hh"
 #include "auth-main.hh"
 #include "auth-caches.hh"
+#include "auth-packetcache.hh"
 #include "auth-zonecache.hh"
 #include "threadname.hh"
 #include "tsigutils.hh"
@@ -2718,6 +2719,10 @@ static void apiServerViewsPOST(HttpRequest* req, HttpResponse* resp)
   if (g_zoneCache.isEnabled()) {
     g_zoneCache.addToView(view, zonename);
   }
+  // Purge packet cache for that zone
+  if (PC.enabled()) {
+    (void)PC.purgeExact(view, zonename.operator const DNSName&());
+  }
 
   resp->body = "";
   resp->status = 204;
@@ -2736,6 +2741,10 @@ static void apiServerViewsDELETE(HttpRequest* req, HttpResponse* resp)
   if (g_zoneCache.isEnabled()) {
     g_zoneCache.removeFromView(view, zoneData.zoneName);
   }
+  // Purge packet cache for that zone
+  if (PC.enabled()) {
+    (void)PC.purgeExact(view, zoneData.zoneName.operator const DNSName&());
+  }
 
   resp->body = "";
   resp->status = 204;