]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Backport to rec-4.9.x: limit maximum size of rr sets in record cache 14745/head rec-4.9.9
authorOtto Moerbeek <otto.moerbeek@open-xchange.com>
Mon, 26 Aug 2024 12:05:01 +0000 (14:05 +0200)
committerOtto Moerbeek <otto.moerbeek@open-xchange.com>
Mon, 26 Aug 2024 12:23:23 +0000 (14:23 +0200)
pdns/recursordist/rec-main.cc
pdns/recursordist/recursor_cache.cc
pdns/recursordist/recursor_cache.hh
pdns/recursordist/test-recursorcache_cc.cc
pdns/recursordist/test-syncres_cc.cc
pdns/recursordist/test-syncres_cc3.cc
pdns/recursordist/test-syncres_cc8.cc
pdns/uuid-utils.cc

index 57c5995545faea03e9f4a468204ffe3dfc06f0bf..2512546504e194c64fb3ce120d41be082455bd07 100644 (file)
@@ -1606,6 +1606,8 @@ static int initSyncRes(Logr::log_t log, const std::optional<std::string>& myHost
     MemRecursorCache::s_maxServedStaleExtensions = sse;
     NegCache::s_maxServedStaleExtensions = sse;
   }
+  MemRecursorCache::s_maxRRSetSize = ::arg().asNum("max-rrset-size");
+  MemRecursorCache::s_limitQTypeAny = ::arg().mustDo("limit-qtype-any");
 
   if (SyncRes::s_tcp_fast_open_connect) {
     checkFastOpenSysctl(true, log);
@@ -3079,6 +3081,8 @@ static void initArgs()
   ::arg().setSwitch("save-parent-ns-set", "Save parent NS set to be used if child NS set fails") = "yes";
   ::arg().set("max-busy-dot-probes", "Maximum number of concurrent DoT probes") = "0";
   ::arg().set("serve-stale-extensions", "Number of times a record's ttl is extended by 30s to be served stale") = "0";
+  ::arg().set("max-rrset-size", "Maximum size of RRSet in cache") = "256";
+  ::arg().setSwitch("limit-qtype-any", "Limit answers to ANY queries in size") = "yes";
 
   ::arg().setCmd("help", "Provide a helpful message");
   ::arg().setCmd("version", "Print version string");
index 0d237cd750895b0c31fc60ec33dfd806c5e5642b..faee38f6d1ac61e2b862dc13a18392cc6c4204e4 100644 (file)
@@ -53,6 +53,8 @@
  */
 
 uint16_t MemRecursorCache::s_maxServedStaleExtensions;
+uint16_t MemRecursorCache::s_maxRRSetSize = 256;
+bool MemRecursorCache::s_limitQTypeAny = true;
 
 MemRecursorCache::MemRecursorCache(size_t mapsCount) :
   d_maps(mapsCount == 0 ? 1 : mapsCount)
@@ -143,6 +145,9 @@ static void updateDNSSECValidationStateFromCache(boost::optional<vState>& state,
 time_t MemRecursorCache::handleHit(MapCombo::LockedContent& content, MemRecursorCache::OrderedTagIterator_t& entry, const DNSName& qname, uint32_t& origTTL, vector<DNSRecord>* res, vector<std::shared_ptr<const RRSIGRecordContent>>* signatures, std::vector<std::shared_ptr<DNSRecord>>* authorityRecs, bool* variable, boost::optional<vState>& state, bool* wasAuth, DNSName* fromAuthZone, ComboAddress* fromAuthIP)
 {
   // MUTEX SHOULD BE ACQUIRED (as indicated by the reference to the content which is protected by a lock)
+  if (entry->d_tooBig) {
+    throw ImmediateServFailException("too many records in RRSet");
+  }
   time_t ttd = entry->d_ttd;
   origTTL = entry->d_orig_ttl;
 
@@ -151,6 +156,10 @@ time_t MemRecursorCache::handleHit(MapCombo::LockedContent& content, MemRecursor
   }
 
   if (res != nullptr) {
+    if (s_limitQTypeAny && res->size() + entry->d_records.size() > s_maxRRSetSize) {
+      throw ImmediateServFailException("too many records in result");
+    }
+
     res->reserve(res->size() + entry->d_records.size());
 
     for (const auto& record : entry->d_records) {
@@ -346,7 +355,7 @@ time_t MemRecursorCache::fakeTTD(MemRecursorCache::OrderedTagIterator_t& entry,
   return ttl;
 }
 // returns -1 for no hits
-time_t MemRecursorCache::get(time_t now, const DNSName& qname, const QType qt, Flags flags, vector<DNSRecord>* res, const ComboAddress& who, const OptTag& routingTag, vector<std::shared_ptr<const RRSIGRecordContent>>* signatures, std::vector<std::shared_ptr<DNSRecord>>* authorityRecs, bool* variable, vState* state, bool* wasAuth, DNSName* fromAuthZone, ComboAddress* fromAuthIP)
+time_t MemRecursorCache::get(time_t now, const DNSName& qname, const QType qt, Flags flags, vector<DNSRecord>* res, const ComboAddress& who, const OptTag& routingTag, vector<std::shared_ptr<const RRSIGRecordContent>>* signatures, std::vector<std::shared_ptr<DNSRecord>>* authorityRecs, bool* variable, vState* state, bool* wasAuth, DNSName* fromAuthZone, ComboAddress* fromAuthIP) // NOLINT(readability-function-cognitive-complexity)
 {
   bool requireAuth = flags & RequireAuth;
   bool refresh = flags & Refresh;
@@ -410,8 +419,8 @@ time_t MemRecursorCache::get(time_t now, const DNSName& qname, const QType qt, F
 
   if (routingTag) {
     auto entries = getEntries(*lockedShard, qname, qt, routingTag);
-    bool found = false;
-    time_t ttd;
+    unsigned int found = 0;
+    time_t ttd{};
 
     if (entries.first != entries.second) {
       OrderedTagIterator_t firstIndexIterator;
@@ -427,17 +436,20 @@ time_t MemRecursorCache::get(time_t now, const DNSName& qname, const QType qt, F
         if (!entryMatches(firstIndexIterator, qtype, requireAuth, who)) {
           continue;
         }
-        found = true;
+        ++found;
 
         handleServeStaleBookkeeping(now, serveStale, firstIndexIterator);
 
         ttd = handleHit(*lockedShard, firstIndexIterator, qname, origTTL, res, signatures, authorityRecs, variable, cachedState, wasAuth, fromAuthZone, fromAuthIP);
 
-        if (qt != QType::ANY && qt != QType::ADDR) { // normally if we have a hit, we are done
+        if (qt == QType::ADDR && found == 2) {
+          break;
+        }
+        if (qt != QType::ANY) { // normally if we have a hit, we are done
           break;
         }
       }
-      if (found) {
+      if (found > 0) {
         if (state && cachedState) {
           *state = *cachedState;
         }
@@ -453,8 +465,8 @@ time_t MemRecursorCache::get(time_t now, const DNSName& qname, const QType qt, F
 
   if (entries.first != entries.second) {
     OrderedTagIterator_t firstIndexIterator;
-    bool found = false;
-    time_t ttd;
+    unsigned int found = 0;
+    time_t ttd{};
 
     for (auto i = entries.first; i != entries.second; ++i) {
       firstIndexIterator = lockedShard->d_map.project<OrderedTag>(i);
@@ -468,17 +480,20 @@ time_t MemRecursorCache::get(time_t now, const DNSName& qname, const QType qt, F
       if (!entryMatches(firstIndexIterator, qtype, requireAuth, who)) {
         continue;
       }
-      found = true;
+      ++found;
 
       handleServeStaleBookkeeping(now, serveStale, firstIndexIterator);
 
       ttd = handleHit(*lockedShard, firstIndexIterator, qname, origTTL, res, signatures, authorityRecs, variable, cachedState, wasAuth, fromAuthZone, fromAuthIP);
 
-      if (qt != QType::ANY && qt != QType::ADDR) { // normally if we have a hit, we are done
+      if (qt == QType::ADDR && found == 2) {
+        break;
+      }
+      if (qt != QType::ANY) { // normally if we have a hit, we are done
         break;
       }
     }
-    if (found) {
+    if (found > 0) {
       if (state && cachedState) {
         *state = *cachedState;
       }
@@ -594,7 +609,6 @@ void MemRecursorCache::replace(time_t now, const DNSName& qname, const QType qt,
   ce.d_signatures = signatures;
   ce.d_authorityRecs = authorityRecs;
   ce.d_records.clear();
-  ce.d_records.reserve(content.size());
   ce.d_authZone = authZone;
   if (from) {
     ce.d_from = *from;
@@ -603,10 +617,19 @@ void MemRecursorCache::replace(time_t now, const DNSName& qname, const QType qt,
     ce.d_from = ComboAddress();
   }
 
-  for (const auto& i : content) {
+  size_t toStore = content.size();
+  if (toStore <= s_maxRRSetSize) {
+    ce.d_tooBig = false;
+  }
+  else {
+    toStore = 1; // record cache does not like empty RRSets
+    ce.d_tooBig = true;
+  }
+  ce.d_records.reserve(toStore);
+  for (const auto& record : content) {
     /* Yes, we have altered the d_ttl value by adding time(nullptr) to it
        prior to calling this function, so the TTL actually holds a TTD. */
-    ce.d_ttd = min(maxTTD, static_cast<time_t>(i.d_ttl)); // XXX this does weird things if TTLs differ in the set
+    ce.d_ttd = min(maxTTD, static_cast<time_t>(record.d_ttl)); // XXX this does weird things if TTLs differ in the set
 
     ce.d_orig_ttl = ce.d_ttd - ttl_time;
     // Even though we record the time the ttd was computed, there still seems to be a case where the computed
@@ -615,7 +638,10 @@ void MemRecursorCache::replace(time_t now, const DNSName& qname, const QType qt,
     if (ce.d_orig_ttl < SyncRes::s_minimumTTL || ce.d_orig_ttl > SyncRes::s_maxcachettl) {
       ce.d_orig_ttl = SyncRes::s_minimumTTL;
     }
-    ce.d_records.push_back(i.getContent());
+    ce.d_records.push_back(record.getContent());
+    if (--toStore == 0) {
+      break;
+    }
   }
 
   if (!isNew) {
@@ -800,7 +826,7 @@ uint64_t MemRecursorCache::doDump(int fd, size_t maxCacheEntries)
       for (const auto& j : i.d_records) {
         count++;
         try {
-          fprintf(fp.get(), "%s %" PRIu32 " %" PRId64 " IN %s %s ; (%s) auth=%i zone=%s from=%s nm=%s rtag=%s ss=%hd\n", i.d_qname.toString().c_str(), i.d_orig_ttl, static_cast<int64_t>(i.d_ttd - now), i.d_qtype.toString().c_str(), j->getZoneRepresentation().c_str(), vStateToString(i.d_state).c_str(), i.d_auth, i.d_authZone.toLogString().c_str(), i.d_from.toString().c_str(), i.d_netmask.empty() ? "" : i.d_netmask.toString().c_str(), !i.d_rtag ? "" : i.d_rtag.get().c_str(), i.d_servedStale);
+          fprintf(fp.get(), "%s %" PRIu32 " %" PRId64 " IN %s %s ; (%s) auth=%i zone=%s from=%s nm=%s rtag=%s ss=%hd%s\n", i.d_qname.toString().c_str(), i.d_orig_ttl, static_cast<int64_t>(i.d_ttd - now), i.d_qtype.toString().c_str(), j->getZoneRepresentation().c_str(), vStateToString(i.d_state).c_str(), i.d_auth, i.d_authZone.toLogString().c_str(), i.d_from.toString().c_str(), i.d_netmask.empty() ? "" : i.d_netmask.toString().c_str(), !i.d_rtag ? "" : i.d_rtag.get().c_str(), i.d_servedStale, i.d_tooBig ? "(big)" : "");
         }
         catch (...) {
           fprintf(fp.get(), "; error printing '%s'\n", i.d_qname.empty() ? "EMPTY" : i.d_qname.toString().c_str());
index 379249b610021b5292230cfc4127a853be1b96c1..55b769ac49055a4bedb39f21e4037951859dd1ff 100644 (file)
@@ -58,6 +58,10 @@ public:
   size_t bytes();
   pair<uint64_t, uint64_t> stats();
   size_t ecsIndexSize();
+  // Maximum size of RRSet we are willing to cache. If the RRSet is larger, we do create an entry,
+  // but mark it as too big. Subsequent gets will cause an ImmediateServFailException to be thrown.
+  static uint16_t s_maxRRSetSize;
+  static bool s_limitQTypeAny;
 
   typedef boost::optional<std::string> OptTag;
 
@@ -124,6 +128,7 @@ private:
     QType d_qtype;
     bool d_auth;
     mutable bool d_submitted; // whether this entry has been queued for refetch
+    bool d_tooBig{false};
   };
 
   /* The ECS Index (d_ecsIndex) keeps track of whether there is any ECS-specific
index e456d3f11681e48827834846c2cfd244890802e7..6836e45f08b5b929b5da43e805d8cc22b9c0aad4 100644 (file)
@@ -8,6 +8,7 @@
 
 #include "iputils.hh"
 #include "recursor_cache.hh"
+#include "syncres.hh"
 
 BOOST_AUTO_TEST_SUITE(recursorcache_cc)
 
@@ -158,6 +159,7 @@ static void simple(time_t now)
     BOOST_CHECK_EQUAL(retrieved.size(), 0U);
 
     // QType::ANY should return any qtype, so from the right subnet we should get all of them
+    MemRecursorCache::s_limitQTypeAny = false;
     BOOST_CHECK_EQUAL(MRC.get(now, power, QType(QType::ANY), MemRecursorCache::None, &retrieved, ComboAddress("192.0.2.3")), (ttd - now));
     BOOST_CHECK_EQUAL(retrieved.size(), 3U);
     for (const auto& rec : retrieved) {
@@ -385,6 +387,52 @@ BOOST_AUTO_TEST_CASE(test_RecursorCacheSimpleDistantFuture)
 }
 #endif
 
+BOOST_AUTO_TEST_CASE(test_RecursorCacheBig)
+{
+  MemRecursorCache MRC;
+
+  std::vector<DNSRecord> records;
+  std::vector<DNSRecord> retrieved;
+  const DNSName authZone(".");
+
+  time_t now = time(nullptr);
+  time_t ttd = now + 30;
+  DNSName power("powerdns.com.");
+  DNSRecord dr0;
+  string dr0Content("2001:DB8::");
+  dr0.d_name = power;
+  dr0.d_type = QType::AAAA;
+  dr0.d_class = QClass::IN;
+  dr0.d_ttl = static_cast<uint32_t>(ttd); // XXX truncation
+  dr0.d_place = DNSResourceRecord::ANSWER;
+  for (int i = 0; i < MemRecursorCache::s_maxRRSetSize; i++) {
+    dr0.setContent(std::make_shared<AAAARecordContent>(dr0Content + std::to_string(i)));
+    records.push_back(dr0);
+  }
+
+  // This one should fit
+  MRC.replace(now, power, QType::AAAA, records, {}, {}, true, authZone, boost::none);
+  BOOST_CHECK_EQUAL(MRC.size(), 1U);
+  BOOST_CHECK_EQUAL(MRC.get(now, power, QType(QType::AAAA), MemRecursorCache::None, &retrieved, ComboAddress()), (ttd - now));
+  BOOST_CHECK_EQUAL(retrieved.size(), MemRecursorCache::s_maxRRSetSize);
+
+  dr0.setContent(std::make_shared<AAAARecordContent>(dr0Content + std::to_string(MemRecursorCache::s_maxRRSetSize)));
+  records.push_back(dr0);
+  // This one is too large and should throw exception
+  MRC.replace(now, power, QType::AAAA, records, {}, {}, true, authZone, boost::none);
+  BOOST_CHECK_EQUAL(MRC.size(), 1U);
+
+  BOOST_CHECK_THROW((void)MRC.get(now, power, QType(QType::AAAA), MemRecursorCache::None, &retrieved, ComboAddress()),
+                    ImmediateServFailException);
+
+  records.resize(1);
+  // This one should fit again
+  MRC.replace(now, power, QType::AAAA, records, {}, {}, true, authZone, boost::none);
+  BOOST_CHECK_EQUAL(MRC.size(), 1U);
+  BOOST_CHECK_EQUAL(MRC.get(now, power, QType(QType::AAAA), MemRecursorCache::None, &retrieved, ComboAddress()), (ttd - now));
+  BOOST_CHECK_EQUAL(retrieved.size(), 1U);
+}
+
 BOOST_AUTO_TEST_CASE(test_RecursorCacheGhost)
 {
   MemRecursorCache MRC;
index 0b0751ae0861b5978da118397d52dc9a8f2215ac..bcbeaf68641181885f69100867671bea0a3787b8 100644 (file)
@@ -140,6 +140,8 @@ void initSR(bool debug)
   }
 
   MemRecursorCache::s_maxServedStaleExtensions = 0;
+  MemRecursorCache::s_maxRRSetSize = 100;
+  MemRecursorCache::s_limitQTypeAny = true;
   NegCache::s_maxServedStaleExtensions = 0;
   g_recCache = std::make_unique<MemRecursorCache>();
   g_negCache = std::make_unique<NegCache>();
index 77ca592c9faf1d97da20d04898014018a2e1b35c..3e49ebea213821281cfc28b1a67f7ecf8e668d56 100644 (file)
@@ -51,27 +51,49 @@ BOOST_AUTO_TEST_CASE(test_unauth_any)
 
   const DNSName target("powerdns.com.");
 
-  sr->setAsyncCallback([target](const ComboAddress& ip, const DNSName& domain, int /* type */, bool /* doTCP */, bool /* sendRDQuery */, int /* EDNS0Level */, struct timeval* /* now */, boost::optional<Netmask>& /* srcmask */, boost::optional<const ResolveContext&> /* context */, LWResult* res, bool* /* chained */) {
-    if (isRootServer(ip)) {
+  sr->setAsyncCallback([&](const ComboAddress& address, const DNSName& domain, int type, bool /* doTCP */, bool /* sendRDQuery */, int /* EDNS0Level */, struct timeval* /* now */, boost::optional<Netmask>& /* srcmask */, boost::optional<const ResolveContext&> /* context */, LWResult* res, bool* /* chained */) {
+    if (isRootServer(address)) {
       setLWResult(res, 0, false, false, true);
       addRecordToLW(res, "com.", QType::NS, "a.gtld-servers.net.", DNSResourceRecord::AUTHORITY, 172800);
       addRecordToLW(res, "a.gtld-servers.net.", QType::A, "192.0.2.1", DNSResourceRecord::ADDITIONAL, 3600);
       return LWResult::Result::Success;
     }
-    else if (ip == ComboAddress("192.0.2.1:53")) {
-
-      setLWResult(res, 0, false, false, true);
-      addRecordToLW(res, domain, QType::A, "192.0.2.42");
-      return LWResult::Result::Success;
+    if (address == ComboAddress("192.0.2.1:53")) {
+      if (type == QType::A) {
+        setLWResult(res, 0, false, false, true);
+        addRecordToLW(res, domain, QType::A, "192.0.2.42");
+        addRecordToLW(res, domain, QType::A, "192.0.2.43");
+        return LWResult::Result::Success;
+      }
+      if (type == QType::AAAA) {
+        setLWResult(res, 0, false, false, true);
+        addRecordToLW(res, domain, QType::AAAA, "::1");
+        return LWResult::Result::Success;
+      }
     }
 
     return LWResult::Result::Timeout;
   });
 
   vector<DNSRecord> ret;
-  int res = sr->beginResolve(target, QType(QType::ANY), QClass::IN, ret);
+  int res = sr->beginResolve(target, QType(QType::A), QClass::IN, ret);
+  BOOST_CHECK_EQUAL(res, RCode::NoError);
+  BOOST_CHECK_EQUAL(ret.size(), 2U);
+
+  ret.clear();
+  res = sr->beginResolve(target, QType(QType::AAAA), QClass::IN, ret);
   BOOST_CHECK_EQUAL(res, RCode::NoError);
   BOOST_CHECK_EQUAL(ret.size(), 1U);
+
+  ret.clear();
+  MemRecursorCache::s_maxRRSetSize = 2;
+  BOOST_CHECK_THROW(sr->beginResolve(target, QType(QType::ANY), QClass::IN, ret), ImmediateServFailException);
+
+  MemRecursorCache::s_limitQTypeAny = false;
+  ret.clear();
+  res = sr->beginResolve(target, QType(QType::ANY), QClass::IN, ret);
+  BOOST_CHECK_EQUAL(res, RCode::NoError);
+  BOOST_CHECK_EQUAL(ret.size(), 3U);
 }
 
 static void test_no_data_f(bool qmin)
index 231f42e66f497873f68982e6917ef28e5211a9aa..e16828e03264c020b1c512ca4b6098e7cf4e95a2 100644 (file)
@@ -1581,6 +1581,20 @@ BOOST_AUTO_TEST_CASE(test_dnssec_validation_from_cache_secure_any)
   ret.clear();
   /* third one _does_ require validation */
   sr->setDNSSECValidationRequested(true);
+  MemRecursorCache::s_maxRRSetSize = 1;
+  BOOST_CHECK_THROW(sr->beginResolve(target, QType(QType::ANY), QClass::IN, ret), ImmediateServFailException);
+  // BOOST_CHECK_EQUAL(res, RCode::NoError);
+  // BOOST_CHECK_EQUAL(sr->getValidationState(), vState::Secure);
+  // BOOST_REQUIRE_EQUAL(ret.size(), 2U);
+  // for (const auto& record : ret) {
+  //   BOOST_CHECK(record.d_type == QType::A || record.d_type == QType::AAAA || record.d_type == QType::RRSIG);
+  // }
+  BOOST_CHECK_EQUAL(queriesCount, 2U);
+
+  ret.clear();
+  /* next one _does_ require validation */
+  MemRecursorCache::s_limitQTypeAny = false;
+  sr->setDNSSECValidationRequested(true);
   res = sr->beginResolve(target, QType(QType::ANY), QClass::IN, ret);
   BOOST_CHECK_EQUAL(res, RCode::NoError);
   BOOST_CHECK_EQUAL(sr->getValidationState(), vState::Secure);
index c59e0a0d0daae040c81bbfbb29357017690e5814..301daff0bb1ee1e9390cee1708c8e6e7eb7bf031 100644 (file)
@@ -30,6 +30,7 @@
 #endif /* BOOST_PENDING_INTEGER_LOG2_HPP */
 #endif /* BOOST_VERSION */
 
+#include <boost/random/mersenne_twister.hpp>
 #include <boost/uuid/uuid_generators.hpp>
 
 // The default of: