From 4775860c55ede7717e6e5702a90632cae5efd28e Mon Sep 17 00:00:00 2001 From: Otto Moerbeek Date: Mon, 26 Aug 2024 14:05:01 +0200 Subject: [PATCH] Backport to rec-4.9.x: limit maximum size of rr sets in record cache --- pdns/recursordist/rec-main.cc | 4 ++ pdns/recursordist/recursor_cache.cc | 58 ++++++++++++++++------ pdns/recursordist/recursor_cache.hh | 5 ++ pdns/recursordist/test-recursorcache_cc.cc | 48 ++++++++++++++++++ pdns/recursordist/test-syncres_cc.cc | 2 + pdns/recursordist/test-syncres_cc3.cc | 38 +++++++++++--- pdns/recursordist/test-syncres_cc8.cc | 14 ++++++ pdns/uuid-utils.cc | 1 + 8 files changed, 146 insertions(+), 24 deletions(-) diff --git a/pdns/recursordist/rec-main.cc b/pdns/recursordist/rec-main.cc index 57c5995545..2512546504 100644 --- a/pdns/recursordist/rec-main.cc +++ b/pdns/recursordist/rec-main.cc @@ -1606,6 +1606,8 @@ static int initSyncRes(Logr::log_t log, const std::optional& myHost MemRecursorCache::s_maxServedStaleExtensions = sse; NegCache::s_maxServedStaleExtensions = sse; } + MemRecursorCache::s_maxRRSetSize = ::arg().asNum("max-rrset-size"); + MemRecursorCache::s_limitQTypeAny = ::arg().mustDo("limit-qtype-any"); if (SyncRes::s_tcp_fast_open_connect) { checkFastOpenSysctl(true, log); @@ -3079,6 +3081,8 @@ static void initArgs() ::arg().setSwitch("save-parent-ns-set", "Save parent NS set to be used if child NS set fails") = "yes"; ::arg().set("max-busy-dot-probes", "Maximum number of concurrent DoT probes") = "0"; ::arg().set("serve-stale-extensions", "Number of times a record's ttl is extended by 30s to be served stale") = "0"; + ::arg().set("max-rrset-size", "Maximum size of RRSet in cache") = "256"; + ::arg().setSwitch("limit-qtype-any", "Limit answers to ANY queries in size") = "yes"; ::arg().setCmd("help", "Provide a helpful message"); ::arg().setCmd("version", "Print version string"); diff --git a/pdns/recursordist/recursor_cache.cc b/pdns/recursordist/recursor_cache.cc index 0d237cd750..faee38f6d1 100644 --- a/pdns/recursordist/recursor_cache.cc +++ b/pdns/recursordist/recursor_cache.cc @@ -53,6 +53,8 @@ */ uint16_t MemRecursorCache::s_maxServedStaleExtensions; +uint16_t MemRecursorCache::s_maxRRSetSize = 256; +bool MemRecursorCache::s_limitQTypeAny = true; MemRecursorCache::MemRecursorCache(size_t mapsCount) : d_maps(mapsCount == 0 ? 1 : mapsCount) @@ -143,6 +145,9 @@ static void updateDNSSECValidationStateFromCache(boost::optional& state, time_t MemRecursorCache::handleHit(MapCombo::LockedContent& content, MemRecursorCache::OrderedTagIterator_t& entry, const DNSName& qname, uint32_t& origTTL, vector* res, vector>* signatures, std::vector>* authorityRecs, bool* variable, boost::optional& state, bool* wasAuth, DNSName* fromAuthZone, ComboAddress* fromAuthIP) { // MUTEX SHOULD BE ACQUIRED (as indicated by the reference to the content which is protected by a lock) + if (entry->d_tooBig) { + throw ImmediateServFailException("too many records in RRSet"); + } time_t ttd = entry->d_ttd; origTTL = entry->d_orig_ttl; @@ -151,6 +156,10 @@ time_t MemRecursorCache::handleHit(MapCombo::LockedContent& content, MemRecursor } if (res != nullptr) { + if (s_limitQTypeAny && res->size() + entry->d_records.size() > s_maxRRSetSize) { + throw ImmediateServFailException("too many records in result"); + } + res->reserve(res->size() + entry->d_records.size()); for (const auto& record : entry->d_records) { @@ -346,7 +355,7 @@ time_t MemRecursorCache::fakeTTD(MemRecursorCache::OrderedTagIterator_t& entry, return ttl; } // returns -1 for no hits -time_t MemRecursorCache::get(time_t now, const DNSName& qname, const QType qt, Flags flags, vector* res, const ComboAddress& who, const OptTag& routingTag, vector>* signatures, std::vector>* authorityRecs, bool* variable, vState* state, bool* wasAuth, DNSName* fromAuthZone, ComboAddress* fromAuthIP) +time_t MemRecursorCache::get(time_t now, const DNSName& qname, const QType qt, Flags flags, vector* res, const ComboAddress& who, const OptTag& routingTag, vector>* signatures, std::vector>* authorityRecs, bool* variable, vState* state, bool* wasAuth, DNSName* fromAuthZone, ComboAddress* fromAuthIP) // NOLINT(readability-function-cognitive-complexity) { bool requireAuth = flags & RequireAuth; bool refresh = flags & Refresh; @@ -410,8 +419,8 @@ time_t MemRecursorCache::get(time_t now, const DNSName& qname, const QType qt, F if (routingTag) { auto entries = getEntries(*lockedShard, qname, qt, routingTag); - bool found = false; - time_t ttd; + unsigned int found = 0; + time_t ttd{}; if (entries.first != entries.second) { OrderedTagIterator_t firstIndexIterator; @@ -427,17 +436,20 @@ time_t MemRecursorCache::get(time_t now, const DNSName& qname, const QType qt, F if (!entryMatches(firstIndexIterator, qtype, requireAuth, who)) { continue; } - found = true; + ++found; handleServeStaleBookkeeping(now, serveStale, firstIndexIterator); ttd = handleHit(*lockedShard, firstIndexIterator, qname, origTTL, res, signatures, authorityRecs, variable, cachedState, wasAuth, fromAuthZone, fromAuthIP); - if (qt != QType::ANY && qt != QType::ADDR) { // normally if we have a hit, we are done + if (qt == QType::ADDR && found == 2) { + break; + } + if (qt != QType::ANY) { // normally if we have a hit, we are done break; } } - if (found) { + if (found > 0) { if (state && cachedState) { *state = *cachedState; } @@ -453,8 +465,8 @@ time_t MemRecursorCache::get(time_t now, const DNSName& qname, const QType qt, F if (entries.first != entries.second) { OrderedTagIterator_t firstIndexIterator; - bool found = false; - time_t ttd; + unsigned int found = 0; + time_t ttd{}; for (auto i = entries.first; i != entries.second; ++i) { firstIndexIterator = lockedShard->d_map.project(i); @@ -468,17 +480,20 @@ time_t MemRecursorCache::get(time_t now, const DNSName& qname, const QType qt, F if (!entryMatches(firstIndexIterator, qtype, requireAuth, who)) { continue; } - found = true; + ++found; handleServeStaleBookkeeping(now, serveStale, firstIndexIterator); ttd = handleHit(*lockedShard, firstIndexIterator, qname, origTTL, res, signatures, authorityRecs, variable, cachedState, wasAuth, fromAuthZone, fromAuthIP); - if (qt != QType::ANY && qt != QType::ADDR) { // normally if we have a hit, we are done + if (qt == QType::ADDR && found == 2) { + break; + } + if (qt != QType::ANY) { // normally if we have a hit, we are done break; } } - if (found) { + if (found > 0) { if (state && cachedState) { *state = *cachedState; } @@ -594,7 +609,6 @@ void MemRecursorCache::replace(time_t now, const DNSName& qname, const QType qt, ce.d_signatures = signatures; ce.d_authorityRecs = authorityRecs; ce.d_records.clear(); - ce.d_records.reserve(content.size()); ce.d_authZone = authZone; if (from) { ce.d_from = *from; @@ -603,10 +617,19 @@ void MemRecursorCache::replace(time_t now, const DNSName& qname, const QType qt, ce.d_from = ComboAddress(); } - for (const auto& i : content) { + size_t toStore = content.size(); + if (toStore <= s_maxRRSetSize) { + ce.d_tooBig = false; + } + else { + toStore = 1; // record cache does not like empty RRSets + ce.d_tooBig = true; + } + ce.d_records.reserve(toStore); + for (const auto& record : content) { /* Yes, we have altered the d_ttl value by adding time(nullptr) to it prior to calling this function, so the TTL actually holds a TTD. */ - ce.d_ttd = min(maxTTD, static_cast(i.d_ttl)); // XXX this does weird things if TTLs differ in the set + ce.d_ttd = min(maxTTD, static_cast(record.d_ttl)); // XXX this does weird things if TTLs differ in the set ce.d_orig_ttl = ce.d_ttd - ttl_time; // Even though we record the time the ttd was computed, there still seems to be a case where the computed @@ -615,7 +638,10 @@ void MemRecursorCache::replace(time_t now, const DNSName& qname, const QType qt, if (ce.d_orig_ttl < SyncRes::s_minimumTTL || ce.d_orig_ttl > SyncRes::s_maxcachettl) { ce.d_orig_ttl = SyncRes::s_minimumTTL; } - ce.d_records.push_back(i.getContent()); + ce.d_records.push_back(record.getContent()); + if (--toStore == 0) { + break; + } } if (!isNew) { @@ -800,7 +826,7 @@ uint64_t MemRecursorCache::doDump(int fd, size_t maxCacheEntries) for (const auto& j : i.d_records) { count++; try { - fprintf(fp.get(), "%s %" PRIu32 " %" PRId64 " IN %s %s ; (%s) auth=%i zone=%s from=%s nm=%s rtag=%s ss=%hd\n", i.d_qname.toString().c_str(), i.d_orig_ttl, static_cast(i.d_ttd - now), i.d_qtype.toString().c_str(), j->getZoneRepresentation().c_str(), vStateToString(i.d_state).c_str(), i.d_auth, i.d_authZone.toLogString().c_str(), i.d_from.toString().c_str(), i.d_netmask.empty() ? "" : i.d_netmask.toString().c_str(), !i.d_rtag ? "" : i.d_rtag.get().c_str(), i.d_servedStale); + fprintf(fp.get(), "%s %" PRIu32 " %" PRId64 " IN %s %s ; (%s) auth=%i zone=%s from=%s nm=%s rtag=%s ss=%hd%s\n", i.d_qname.toString().c_str(), i.d_orig_ttl, static_cast(i.d_ttd - now), i.d_qtype.toString().c_str(), j->getZoneRepresentation().c_str(), vStateToString(i.d_state).c_str(), i.d_auth, i.d_authZone.toLogString().c_str(), i.d_from.toString().c_str(), i.d_netmask.empty() ? "" : i.d_netmask.toString().c_str(), !i.d_rtag ? "" : i.d_rtag.get().c_str(), i.d_servedStale, i.d_tooBig ? "(big)" : ""); } catch (...) { fprintf(fp.get(), "; error printing '%s'\n", i.d_qname.empty() ? "EMPTY" : i.d_qname.toString().c_str()); diff --git a/pdns/recursordist/recursor_cache.hh b/pdns/recursordist/recursor_cache.hh index 379249b610..55b769ac49 100644 --- a/pdns/recursordist/recursor_cache.hh +++ b/pdns/recursordist/recursor_cache.hh @@ -58,6 +58,10 @@ public: size_t bytes(); pair stats(); size_t ecsIndexSize(); + // Maximum size of RRSet we are willing to cache. If the RRSet is larger, we do create an entry, + // but mark it as too big. Subsequent gets will cause an ImmediateServFailException to be thrown. + static uint16_t s_maxRRSetSize; + static bool s_limitQTypeAny; typedef boost::optional OptTag; @@ -124,6 +128,7 @@ private: QType d_qtype; bool d_auth; mutable bool d_submitted; // whether this entry has been queued for refetch + bool d_tooBig{false}; }; /* The ECS Index (d_ecsIndex) keeps track of whether there is any ECS-specific diff --git a/pdns/recursordist/test-recursorcache_cc.cc b/pdns/recursordist/test-recursorcache_cc.cc index e456d3f116..6836e45f08 100644 --- a/pdns/recursordist/test-recursorcache_cc.cc +++ b/pdns/recursordist/test-recursorcache_cc.cc @@ -8,6 +8,7 @@ #include "iputils.hh" #include "recursor_cache.hh" +#include "syncres.hh" BOOST_AUTO_TEST_SUITE(recursorcache_cc) @@ -158,6 +159,7 @@ static void simple(time_t now) BOOST_CHECK_EQUAL(retrieved.size(), 0U); // QType::ANY should return any qtype, so from the right subnet we should get all of them + MemRecursorCache::s_limitQTypeAny = false; BOOST_CHECK_EQUAL(MRC.get(now, power, QType(QType::ANY), MemRecursorCache::None, &retrieved, ComboAddress("192.0.2.3")), (ttd - now)); BOOST_CHECK_EQUAL(retrieved.size(), 3U); for (const auto& rec : retrieved) { @@ -385,6 +387,52 @@ BOOST_AUTO_TEST_CASE(test_RecursorCacheSimpleDistantFuture) } #endif +BOOST_AUTO_TEST_CASE(test_RecursorCacheBig) +{ + MemRecursorCache MRC; + + std::vector records; + std::vector retrieved; + const DNSName authZone("."); + + time_t now = time(nullptr); + time_t ttd = now + 30; + DNSName power("powerdns.com."); + DNSRecord dr0; + string dr0Content("2001:DB8::"); + dr0.d_name = power; + dr0.d_type = QType::AAAA; + dr0.d_class = QClass::IN; + dr0.d_ttl = static_cast(ttd); // XXX truncation + dr0.d_place = DNSResourceRecord::ANSWER; + for (int i = 0; i < MemRecursorCache::s_maxRRSetSize; i++) { + dr0.setContent(std::make_shared(dr0Content + std::to_string(i))); + records.push_back(dr0); + } + + // This one should fit + MRC.replace(now, power, QType::AAAA, records, {}, {}, true, authZone, boost::none); + BOOST_CHECK_EQUAL(MRC.size(), 1U); + BOOST_CHECK_EQUAL(MRC.get(now, power, QType(QType::AAAA), MemRecursorCache::None, &retrieved, ComboAddress()), (ttd - now)); + BOOST_CHECK_EQUAL(retrieved.size(), MemRecursorCache::s_maxRRSetSize); + + dr0.setContent(std::make_shared(dr0Content + std::to_string(MemRecursorCache::s_maxRRSetSize))); + records.push_back(dr0); + // This one is too large and should throw exception + MRC.replace(now, power, QType::AAAA, records, {}, {}, true, authZone, boost::none); + BOOST_CHECK_EQUAL(MRC.size(), 1U); + + BOOST_CHECK_THROW((void)MRC.get(now, power, QType(QType::AAAA), MemRecursorCache::None, &retrieved, ComboAddress()), + ImmediateServFailException); + + records.resize(1); + // This one should fit again + MRC.replace(now, power, QType::AAAA, records, {}, {}, true, authZone, boost::none); + BOOST_CHECK_EQUAL(MRC.size(), 1U); + BOOST_CHECK_EQUAL(MRC.get(now, power, QType(QType::AAAA), MemRecursorCache::None, &retrieved, ComboAddress()), (ttd - now)); + BOOST_CHECK_EQUAL(retrieved.size(), 1U); +} + BOOST_AUTO_TEST_CASE(test_RecursorCacheGhost) { MemRecursorCache MRC; diff --git a/pdns/recursordist/test-syncres_cc.cc b/pdns/recursordist/test-syncres_cc.cc index 0b0751ae08..bcbeaf6864 100644 --- a/pdns/recursordist/test-syncres_cc.cc +++ b/pdns/recursordist/test-syncres_cc.cc @@ -140,6 +140,8 @@ void initSR(bool debug) } MemRecursorCache::s_maxServedStaleExtensions = 0; + MemRecursorCache::s_maxRRSetSize = 100; + MemRecursorCache::s_limitQTypeAny = true; NegCache::s_maxServedStaleExtensions = 0; g_recCache = std::make_unique(); g_negCache = std::make_unique(); diff --git a/pdns/recursordist/test-syncres_cc3.cc b/pdns/recursordist/test-syncres_cc3.cc index 77ca592c9f..3e49ebea21 100644 --- a/pdns/recursordist/test-syncres_cc3.cc +++ b/pdns/recursordist/test-syncres_cc3.cc @@ -51,27 +51,49 @@ BOOST_AUTO_TEST_CASE(test_unauth_any) const DNSName target("powerdns.com."); - sr->setAsyncCallback([target](const ComboAddress& ip, const DNSName& domain, int /* type */, bool /* doTCP */, bool /* sendRDQuery */, int /* EDNS0Level */, struct timeval* /* now */, boost::optional& /* srcmask */, boost::optional /* context */, LWResult* res, bool* /* chained */) { - if (isRootServer(ip)) { + sr->setAsyncCallback([&](const ComboAddress& address, const DNSName& domain, int type, bool /* doTCP */, bool /* sendRDQuery */, int /* EDNS0Level */, struct timeval* /* now */, boost::optional& /* srcmask */, boost::optional /* context */, LWResult* res, bool* /* chained */) { + if (isRootServer(address)) { setLWResult(res, 0, false, false, true); addRecordToLW(res, "com.", QType::NS, "a.gtld-servers.net.", DNSResourceRecord::AUTHORITY, 172800); addRecordToLW(res, "a.gtld-servers.net.", QType::A, "192.0.2.1", DNSResourceRecord::ADDITIONAL, 3600); return LWResult::Result::Success; } - else if (ip == ComboAddress("192.0.2.1:53")) { - - setLWResult(res, 0, false, false, true); - addRecordToLW(res, domain, QType::A, "192.0.2.42"); - return LWResult::Result::Success; + if (address == ComboAddress("192.0.2.1:53")) { + if (type == QType::A) { + setLWResult(res, 0, false, false, true); + addRecordToLW(res, domain, QType::A, "192.0.2.42"); + addRecordToLW(res, domain, QType::A, "192.0.2.43"); + return LWResult::Result::Success; + } + if (type == QType::AAAA) { + setLWResult(res, 0, false, false, true); + addRecordToLW(res, domain, QType::AAAA, "::1"); + return LWResult::Result::Success; + } } return LWResult::Result::Timeout; }); vector ret; - int res = sr->beginResolve(target, QType(QType::ANY), QClass::IN, ret); + int res = sr->beginResolve(target, QType(QType::A), QClass::IN, ret); + BOOST_CHECK_EQUAL(res, RCode::NoError); + BOOST_CHECK_EQUAL(ret.size(), 2U); + + ret.clear(); + res = sr->beginResolve(target, QType(QType::AAAA), QClass::IN, ret); BOOST_CHECK_EQUAL(res, RCode::NoError); BOOST_CHECK_EQUAL(ret.size(), 1U); + + ret.clear(); + MemRecursorCache::s_maxRRSetSize = 2; + BOOST_CHECK_THROW(sr->beginResolve(target, QType(QType::ANY), QClass::IN, ret), ImmediateServFailException); + + MemRecursorCache::s_limitQTypeAny = false; + ret.clear(); + res = sr->beginResolve(target, QType(QType::ANY), QClass::IN, ret); + BOOST_CHECK_EQUAL(res, RCode::NoError); + BOOST_CHECK_EQUAL(ret.size(), 3U); } static void test_no_data_f(bool qmin) diff --git a/pdns/recursordist/test-syncres_cc8.cc b/pdns/recursordist/test-syncres_cc8.cc index 231f42e66f..e16828e032 100644 --- a/pdns/recursordist/test-syncres_cc8.cc +++ b/pdns/recursordist/test-syncres_cc8.cc @@ -1581,6 +1581,20 @@ BOOST_AUTO_TEST_CASE(test_dnssec_validation_from_cache_secure_any) ret.clear(); /* third one _does_ require validation */ sr->setDNSSECValidationRequested(true); + MemRecursorCache::s_maxRRSetSize = 1; + BOOST_CHECK_THROW(sr->beginResolve(target, QType(QType::ANY), QClass::IN, ret), ImmediateServFailException); + // BOOST_CHECK_EQUAL(res, RCode::NoError); + // BOOST_CHECK_EQUAL(sr->getValidationState(), vState::Secure); + // BOOST_REQUIRE_EQUAL(ret.size(), 2U); + // for (const auto& record : ret) { + // BOOST_CHECK(record.d_type == QType::A || record.d_type == QType::AAAA || record.d_type == QType::RRSIG); + // } + BOOST_CHECK_EQUAL(queriesCount, 2U); + + ret.clear(); + /* next one _does_ require validation */ + MemRecursorCache::s_limitQTypeAny = false; + sr->setDNSSECValidationRequested(true); res = sr->beginResolve(target, QType(QType::ANY), QClass::IN, ret); BOOST_CHECK_EQUAL(res, RCode::NoError); BOOST_CHECK_EQUAL(sr->getValidationState(), vState::Secure); diff --git a/pdns/uuid-utils.cc b/pdns/uuid-utils.cc index c59e0a0d0d..301daff0bb 100644 --- a/pdns/uuid-utils.cc +++ b/pdns/uuid-utils.cc @@ -30,6 +30,7 @@ #endif /* BOOST_PENDING_INTEGER_LOG2_HPP */ #endif /* BOOST_VERSION */ +#include #include // The default of: -- 2.47.2