From: Remi Gacogne Date: Mon, 5 Nov 2018 17:02:33 +0000 (+0100) Subject: dnsdist: Prevent unlikely DO collisions in the packet cache X-Git-Tag: dnsdist-1.3.3^2~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d7728dafb8805a77011f384f2fe6daea3bd00061;p=thirdparty%2Fpdns.git dnsdist: Prevent unlikely DO collisions in the packet cache --- diff --git a/pdns/dnsdist-cache.cc b/pdns/dnsdist-cache.cc index 6c84357a55..b4c11df37a 100644 --- a/pdns/dnsdist-cache.cc +++ b/pdns/dnsdist-cache.cc @@ -78,9 +78,9 @@ bool DNSDistPacketCache::getClientSubnet(const char* packet, unsigned int consum return false; } -bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp, const boost::optional& subnet) const +bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp, bool dnssecOK, const boost::optional& subnet) const { - if (cachedValue.queryFlags != queryFlags || cachedValue.tcp != tcp || cachedValue.qtype != qtype || cachedValue.qclass != qclass || cachedValue.qname != qname) { + if (cachedValue.queryFlags != queryFlags || cachedValue.dnssecOK != dnssecOK || cachedValue.tcp != tcp || cachedValue.qtype != qtype || cachedValue.qclass != qclass || cachedValue.qname != qname) { return false; } @@ -113,7 +113,7 @@ void DNSDistPacketCache::insertLocked(CacheShard& shard, uint32_t key, CacheValu CacheValue& value = it->second; bool wasExpired = value.validity <= newValue.added; - if (!wasExpired && !cachedValueMatches(value, newValue.queryFlags, newValue.qname, newValue.qtype, newValue.qclass, newValue.tcp, newValue.subnet)) { + if (!wasExpired && !cachedValueMatches(value, newValue.queryFlags, newValue.qname, newValue.qtype, newValue.qclass, newValue.tcp, newValue.dnssecOK, newValue.subnet)) { d_insertCollisions++; return; } @@ -126,7 +126,7 @@ void DNSDistPacketCache::insertLocked(CacheShard& shard, uint32_t key, CacheValu value = newValue; } -void DNSDistPacketCache::insert(uint32_t key, const boost::optional& subnet, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode, boost::optional tempFailureTTL) +void DNSDistPacketCache::insert(uint32_t key, const boost::optional& subnet, uint16_t queryFlags, bool dnssecOK, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode, boost::optional tempFailureTTL) { if (responseLen < sizeof(dnsheader)) { return; @@ -179,6 +179,7 @@ void DNSDistPacketCache::insert(uint32_t key, const boost::optional& su newValue.validity = newValidity; newValue.added = now; newValue.tcp = tcp; + newValue.dnssecOK = dnssecOK; newValue.value = std::string(response, responseLen); newValue.subnet = subnet; @@ -200,7 +201,7 @@ void DNSDistPacketCache::insert(uint32_t key, const boost::optional& su } } -bool DNSDistPacketCache::get(const DNSQuestion& dq, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, boost::optional& subnet, uint32_t allowExpired, bool skipAging) +bool DNSDistPacketCache::get(const DNSQuestion& dq, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, boost::optional& subnet, bool dnssecOK, uint32_t allowExpired, bool skipAging) { std::string dnsQName(dq.qname->toDNSString()); uint32_t key = getKey(dnsQName, consumed, reinterpret_cast(dq.dh), dq.len, dq.tcp); @@ -247,7 +248,7 @@ bool DNSDistPacketCache::get(const DNSQuestion& dq, uint16_t consumed, uint16_t } /* check for collision */ - if (!cachedValueMatches(value, *(getFlagsFromDNSHeader(dq.dh)), *dq.qname, dq.qtype, dq.qclass, dq.tcp, subnet)) { + if (!cachedValueMatches(value, *(getFlagsFromDNSHeader(dq.dh)), *dq.qname, dq.qtype, dq.qclass, dq.tcp, dnssecOK, subnet)) { d_lookupCollisions++; return false; } diff --git a/pdns/dnsdist-cache.hh b/pdns/dnsdist-cache.hh index 1263c0a0e3..14902c400c 100644 --- a/pdns/dnsdist-cache.hh +++ b/pdns/dnsdist-cache.hh @@ -33,8 +33,8 @@ public: DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL=86400, uint32_t minTTL=0, uint32_t tempFailureTTL=60, uint32_t maxNegativeTTL=3600, uint32_t staleTTL=60, bool dontAge=false, uint32_t shards=1, bool deferrableInsertLock=true, bool parseECS=false); ~DNSDistPacketCache(); - void insert(uint32_t key, const boost::optional& subnet, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode, boost::optional tempFailureTTL); - bool get(const DNSQuestion& dq, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, boost::optional& subnetOut, uint32_t allowExpired=0, bool skipAging=false); + void insert(uint32_t key, const boost::optional& subnet, uint16_t queryFlags, bool dnssecOK, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode, boost::optional tempFailureTTL); + bool get(const DNSQuestion& dq, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, boost::optional& subnetOut, bool dnssecOK, uint32_t allowExpired=0, bool skipAging=false); void purgeExpired(size_t upTo=0); void expunge(size_t upTo=0); void expungeByName(const DNSName& name, uint16_t qtype=QType::ANY, bool suffixMatch=false); @@ -70,6 +70,7 @@ private: time_t validity{0}; uint16_t len{0}; bool tcp{false}; + bool dnssecOK{false}; }; class CacheShard @@ -95,7 +96,7 @@ private: }; static bool getClientSubnet(const char* packet, unsigned int consumed, uint16_t len, boost::optional& subnet); - bool cachedValueMatches(const CacheValue& cachedValue, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp, const boost::optional& subnet) const; + bool cachedValueMatches(const CacheValue& cachedValue, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp, bool dnssecOK, const boost::optional& subnet) const; uint32_t getShardIndex(uint32_t key) const; void insertLocked(CacheShard& shard, uint32_t key, CacheValue& newValue); diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index 98b6e78371..131fd1a43a 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -442,11 +442,13 @@ void* tcpClientThread(int pipefd) uint32_t cacheKey = 0; boost::optional subnet; + bool dnssecOK = false; if (packetCache && !dq.skipCache) { char cachedResponse[4096]; uint16_t cachedResponseSize = sizeof cachedResponse; uint32_t allowExpired = ds ? 0 : g_staleCacheEntriesTTL; - if (packetCache->get(dq, (uint16_t) consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey, subnet, allowExpired)) { + dnssecOK = (getEDNSZ(dq) & EDNS_HEADER_FLAG_DO); + if (packetCache->get(dq, (uint16_t) consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey, subnet, dnssecOK, allowExpired)) { DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, (dnsheader*) cachedResponse, sizeof cachedResponse, cachedResponseSize, true, &queryRealTime); #ifdef HAVE_PROTOBUF dr.uniqueId = dq.uniqueId; @@ -630,7 +632,7 @@ void* tcpClientThread(int pipefd) } if (packetCache && !dq.skipCache) { - packetCache->insert(cacheKey, subnet, origFlags, qname, qtype, qclass, response, responseLen, true, dh->rcode, dq.tempFailureTTL); + packetCache->insert(cacheKey, subnet, origFlags, dnssecOK, qname, qtype, qclass, response, responseLen, true, dh->rcode, dq.tempFailureTTL); } #ifdef HAVE_DNSCRYPT diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index ecd402b826..115b0f707e 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -518,7 +518,7 @@ try { } if (ids->packetCache && !ids->skipCache) { - ids->packetCache->insert(ids->cacheKey, ids->subnet, ids->origFlags, ids->qname, ids->qtype, ids->qclass, response, responseLen, false, dh->rcode, ids->tempFailureTTL); + ids->packetCache->insert(ids->cacheKey, ids->subnet, ids->origFlags, ids->dnssecOK, ids->qname, ids->qtype, ids->qclass, response, responseLen, false, dh->rcode, ids->tempFailureTTL); } if (ids->cs && !ids->cs->muted) { @@ -1326,6 +1326,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct unsigned int consumed = 0; DNSName qname(query, len, sizeof(dnsheader), false, &qtype, &qclass, &consumed); DNSQuestion dq(&qname, qtype, qclass, consumed, dest.sin4.sin_family != 0 ? &dest : &cs.local, &remote, dh, queryBufferSize, len, false, &queryRealTime); + bool dnssecOK = false; if (!processQuery(holders, dq, poolname, &delayMsec, now)) { @@ -1402,7 +1403,8 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct if (packetCache && !dq.skipCache) { uint16_t cachedResponseSize = dq.size; uint32_t allowExpired = ss ? 0 : g_staleCacheEntriesTTL; - if (packetCache->get(dq, consumed, dh->id, query, &cachedResponseSize, &cacheKey, subnet, allowExpired)) { + dnssecOK = (getEDNSZ(dq) & EDNS_HEADER_FLAG_DO); + if (packetCache->get(dq, consumed, dh->id, query, &cachedResponseSize, &cacheKey, subnet, dnssecOK, allowExpired)) { DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast(query), dq.size, cachedResponseSize, false, &queryRealTime); #ifdef HAVE_PROTOBUF dr.uniqueId = dq.uniqueId; @@ -1519,6 +1521,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct ids->ednsAdded = ednsAdded; ids->ecsAdded = ecsAdded; ids->qTag = dq.qTag; + ids->dnssecOK = dnssecOK; /* If we couldn't harvest the real dest addr, still write down the listening addr since it will be useful diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index 8cb7c062ca..096ff8d8e5 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -536,6 +536,7 @@ struct IDState bool ecsAdded{false}; bool skipCache{false}; bool destHarvested{false}; // if true, origDest holds the original dest addr, otherwise the listening addr + bool dnssecOK{false}; }; typedef std::unordered_map QueryCountRecords; diff --git a/pdns/test-dnsdistpacketcache_cc.cc b/pdns/test-dnsdistpacketcache_cc.cc index 281949f78e..5b1fbc0f34 100644 --- a/pdns/test-dnsdistpacketcache_cc.cc +++ b/pdns/test-dnsdistpacketcache_cc.cc @@ -24,6 +24,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) { size_t counter=0; size_t skipped=0; ComboAddress remote; + bool dnssecOK = false; try { for(counter = 0; counter < 100000; ++counter) { DNSName a=DNSName(std::to_string(counter))+DNSName(" hello"); @@ -50,13 +51,13 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) { boost::optional subnet; auto dh = reinterpret_cast(query.data()); DNSQuestion dq(&a, QType::A, QClass::IN, 0, &remote, &remote, dh, query.size(), query.size(), false, &queryTime); - bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet); + bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet, dnssecOK); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dh)), a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, 0, boost::none); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dh)), dnssecOK, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, 0, boost::none); - found = PC.get(dq, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, 0, true); + found = PC.get(dq, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, dnssecOK, 0, true); if (found == true) { BOOST_CHECK_EQUAL(responseBufSize, responseLen); int match = memcmp(responseBuf, response.data(), responseLen); @@ -83,7 +84,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) { uint32_t key = 0; boost::optional subnet; DNSQuestion dq(&a, QType::A, QClass::IN, 0, &remote, &remote, (struct dnsheader*) query.data(), query.size(), query.size(), false, &queryTime); - bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet); + bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet, dnssecOK); if (found == true) { PC.expungeByName(a); deleted++; @@ -106,7 +107,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) { char response[4096]; uint16_t responseSize = sizeof(response); DNSQuestion dq(&a, QType::A, QClass::IN, 0, &remote, &remote, (struct dnsheader*) query.data(), len, query.size(), false, &queryTime); - if(PC.get(dq, a.wirelength(), pwQ.getHeader()->id, response, &responseSize, &key, subnet)) { + if(PC.get(dq, a.wirelength(), pwQ.getHeader()->id, response, &responseSize, &key, subnet, dnssecOK)) { matches++; } } @@ -128,6 +129,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheServFailTTL) { gettime(&queryTime); // does not have to be accurate ("realTime") in tests ComboAddress remote; + bool dnssecOK = false; try { DNSName a = DNSName("servfail"); BOOST_CHECK_EQUAL(DNSName(a.toString()), a); @@ -152,19 +154,19 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheServFailTTL) { boost::optional subnet; auto dh = reinterpret_cast(query.data()); DNSQuestion dq(&a, QType::A, QClass::IN, 0, &remote, &remote, dh, query.size(), query.size(), false, &queryTime); - bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet); + bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet, dnssecOK); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); // Insert with failure-TTL of 0 (-> should not enter cache). - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dh)), a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, RCode::ServFail, boost::optional(0)); - found = PC.get(dq, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, 0, true); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dh)), dnssecOK, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, RCode::ServFail, boost::optional(0)); + found = PC.get(dq, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, dnssecOK, 0, true); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); // Insert with failure-TTL non-zero (-> should enter cache). - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dh)), a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, RCode::ServFail, boost::optional(300)); - found = PC.get(dq, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, 0, true); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dh)), dnssecOK, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, RCode::ServFail, boost::optional(300)); + found = PC.get(dq, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, dnssecOK, 0, true); BOOST_CHECK_EQUAL(found, true); BOOST_CHECK(!subnet); } @@ -182,6 +184,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNoDataTTL) { gettime(&queryTime); // does not have to be accurate ("realTime") in tests ComboAddress remote; + bool dnssecOK = false; try { DNSName name("nodata"); vector query; @@ -209,18 +212,18 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNoDataTTL) { boost::optional subnet; auto dh = reinterpret_cast(query.data()); DNSQuestion dq(&name, QType::A, QClass::IN, 0, &remote, &remote, dh, query.size(), query.size(), false, &queryTime); - bool found = PC.get(dq, name.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet); + bool found = PC.get(dq, name.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet, dnssecOK); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dh)), name, QType::A, QClass::IN, reinterpret_cast(response.data()), responseLen, false, RCode::NoError, boost::none); - found = PC.get(dq, name.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, 0, true); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dh)), dnssecOK, name, QType::A, QClass::IN, reinterpret_cast(response.data()), responseLen, false, RCode::NoError, boost::none); + found = PC.get(dq, name.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, dnssecOK, 0, true); BOOST_CHECK_EQUAL(found, true); BOOST_CHECK(!subnet); sleep(2); /* it should have expired by now */ - found = PC.get(dq, name.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, 0, true); + found = PC.get(dq, name.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, dnssecOK, 0, true); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); } @@ -238,6 +241,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNXDomainTTL) { gettime(&queryTime); // does not have to be accurate ("realTime") in tests ComboAddress remote; + bool dnssecOK = false; try { DNSName name("nxdomain"); vector query; @@ -265,18 +269,18 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNXDomainTTL) { boost::optional subnet; auto dh = reinterpret_cast(query.data()); DNSQuestion dq(&name, QType::A, QClass::IN, 0, &remote, &remote, dh, query.size(), query.size(), false, &queryTime); - bool found = PC.get(dq, name.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet); + bool found = PC.get(dq, name.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet, dnssecOK); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dh)), name, QType::A, QClass::IN, reinterpret_cast(response.data()), responseLen, false, RCode::NXDomain, boost::none); - found = PC.get(dq, name.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, 0, true); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dh)), dnssecOK, name, QType::A, QClass::IN, reinterpret_cast(response.data()), responseLen, false, RCode::NXDomain, boost::none); + found = PC.get(dq, name.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, dnssecOK, 0, true); BOOST_CHECK_EQUAL(found, true); BOOST_CHECK(!subnet); sleep(2); /* it should have expired by now */ - found = PC.get(dq, name.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, 0, true); + found = PC.get(dq, name.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, dnssecOK, 0, true); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); } @@ -294,6 +298,7 @@ static void *threadMangler(void* off) gettime(&queryTime); // does not have to be accurate ("realTime") in tests try { ComboAddress remote; + bool dnssecOK = false; unsigned int offset=(unsigned int)(unsigned long)off; for(unsigned int counter=0; counter < 100000; ++counter) { DNSName a=DNSName("hello ")+DNSName(std::to_string(counter+offset)); @@ -318,9 +323,9 @@ static void *threadMangler(void* off) boost::optional subnet; auto dh = reinterpret_cast(query.data()); DNSQuestion dq(&a, QType::A, QClass::IN, 0, &remote, &remote, dh, query.size(), query.size(), false, &queryTime); - PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet); + PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet, dnssecOK); - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dh)), a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, 0, boost::none); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dh)), dnssecOK, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, 0, boost::none); } } catch(PDNSException& e) { @@ -334,6 +339,7 @@ AtomicCounter g_missing; static void *threadReader(void* off) { + bool dnssecOK = false; struct timespec queryTime; gettime(&queryTime); // does not have to be accurate ("realTime") in tests try @@ -352,7 +358,7 @@ static void *threadReader(void* off) uint32_t key = 0; boost::optional subnet; DNSQuestion dq(&a, QType::A, QClass::IN, 0, &remote, &remote, (struct dnsheader*) query.data(), query.size(), query.size(), false, &queryTime); - bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet); + bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet, dnssecOK); if (!found) { g_missing++; } @@ -402,6 +408,7 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) { uint32_t key; uint32_t secondKey; boost::optional subnetOut; + bool dnssecOK = false; /* lookup for a query with an ECS value of 10.0.118.46/32, insert a corresponding response */ @@ -423,7 +430,7 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) { struct timespec queryTime; gettime(&queryTime); DNSQuestion dq(&qname, QType::AAAA, QClass::IN, 0, &remote, &remote, pwQ.getHeader(), query.size(), query.size(), false, &queryTime); - bool found = PC.get(dq, qname.wirelength(), 0, responseBuf, &responseBufSize, &key, subnetOut); + bool found = PC.get(dq, qname.wirelength(), 0, responseBuf, &responseBufSize, &key, subnetOut, dnssecOK); BOOST_CHECK_EQUAL(found, false); BOOST_REQUIRE(subnetOut); BOOST_CHECK_EQUAL(subnetOut->toString(), opt.source.toString()); @@ -439,10 +446,10 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) { pwR.addOpt(512, 0, 0, ednsOptions); pwR.commit(); - PC.insert(key, subnetOut, *(getFlagsFromDNSHeader(pwR.getHeader())), qname, qtype, QClass::IN, reinterpret_cast(response.data()), response.size(), false, RCode::NoError, boost::none); + PC.insert(key, subnetOut, *(getFlagsFromDNSHeader(pwR.getHeader())), dnssecOK, qname, qtype, QClass::IN, reinterpret_cast(response.data()), response.size(), false, RCode::NoError, boost::none); BOOST_CHECK_EQUAL(PC.getSize(), 1); - found = PC.get(dq, qname.wirelength(), 0, responseBuf, &responseBufSize, &key, subnetOut); + found = PC.get(dq, qname.wirelength(), 0, responseBuf, &responseBufSize, &key, subnetOut, dnssecOK); BOOST_CHECK_EQUAL(found, true); BOOST_REQUIRE(subnetOut); BOOST_CHECK_EQUAL(subnetOut->toString(), opt.source.toString()); @@ -468,7 +475,7 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) { struct timespec queryTime; gettime(&queryTime); DNSQuestion dq(&qname, QType::AAAA, QClass::IN, 0, &remote, &remote, pwQ.getHeader(), query.size(), query.size(), false, &queryTime); - bool found = PC.get(dq, qname.wirelength(), 0, responseBuf, &responseBufSize, &secondKey, subnetOut); + bool found = PC.get(dq, qname.wirelength(), 0, responseBuf, &responseBufSize, &secondKey, subnetOut, dnssecOK); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK_EQUAL(secondKey, key); BOOST_REQUIRE(subnetOut); @@ -477,4 +484,58 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) { } } +BOOST_AUTO_TEST_CASE(test_PCDNSSECCollision) { + const size_t maxEntries = 150000; + DNSDistPacketCache PC(maxEntries, 86400, 1, 60, 3600, 60, false, 1, true, true); + BOOST_CHECK_EQUAL(PC.getSize(), 0); + + DNSName qname("www.powerdns.com."); + uint16_t qtype = QType::AAAA; + uint16_t qid = 0x42; + uint32_t key; + boost::optional subnetOut; + + /* lookup for a query with DNSSEC OK, + insert a corresponding response with DO set, + check that it doesn't match without DO, but does with it */ + { + vector query; + DNSPacketWriter pwQ(query, qname, qtype, QClass::IN, 0); + pwQ.getHeader()->rd = 1; + pwQ.getHeader()->id = qid; + pwQ.addOpt(512, 0, EDNS_HEADER_FLAG_DO); + pwQ.commit(); + + char responseBuf[4096]; + uint16_t responseBufSize = sizeof(responseBuf); + ComboAddress remote("192.0.2.1"); + struct timespec queryTime; + gettime(&queryTime); + DNSQuestion dq(&qname, QType::AAAA, QClass::IN, 0, &remote, &remote, pwQ.getHeader(), query.size(), query.size(), false, &queryTime); + bool found = PC.get(dq, qname.wirelength(), 0, responseBuf, &responseBufSize, &key, subnetOut, true); + BOOST_CHECK_EQUAL(found, false); + + vector response; + DNSPacketWriter pwR(response, qname, qtype, QClass::IN, 0); + pwR.getHeader()->rd = 1; + pwR.getHeader()->id = qid; + pwR.startRecord(qname, qtype, 100, QClass::IN, DNSResourceRecord::ANSWER); + ComboAddress v6("::1"); + pwR.xfrCAWithoutPort(6, v6); + pwR.commit(); + pwR.addOpt(512, 0, EDNS_HEADER_FLAG_DO); + pwR.commit(); + + PC.insert(key, subnetOut, *(getFlagsFromDNSHeader(pwR.getHeader())), /* DNSSEC OK is set */ true, qname, qtype, QClass::IN, reinterpret_cast(response.data()), response.size(), false, RCode::NoError, boost::none); + BOOST_CHECK_EQUAL(PC.getSize(), 1); + + found = PC.get(dq, qname.wirelength(), 0, responseBuf, &responseBufSize, &key, subnetOut, false); + BOOST_CHECK_EQUAL(found, false); + + found = PC.get(dq, qname.wirelength(), 0, responseBuf, &responseBufSize, &key, subnetOut, true); + BOOST_CHECK_EQUAL(found, true); + } + +} + BOOST_AUTO_TEST_SUITE_END()