From: Remi Gacogne Date: Fri, 23 Dec 2016 09:25:45 +0000 (+0100) Subject: dnsdist: Handle Refused as ServFail in the packet cache X-Git-Tag: dnsdist-1.1.0^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2714396eed4a84d5e73e2769d5c00e34116241a8;p=thirdparty%2Fpdns.git dnsdist: Handle Refused as ServFail in the packet cache --- diff --git a/pdns/README-dnsdist.md b/pdns/README-dnsdist.md index e899c0f8d2..6b47f3e747 100644 --- a/pdns/README-dnsdist.md +++ b/pdns/README-dnsdist.md @@ -855,8 +855,8 @@ The first parameter is the maximum number of entries stored in the cache, and is only one required. All the others parameters are optional and in seconds. The second one is the maximum lifetime of an entry in the cache, the third one is the minimum TTL an entry should have to be considered for insertion in the cache, -the fourth one is the TTL used for a Server Failure response. The last one is the -TTL that will be used when a stale cache entry is returned. +the fourth one is the TTL used for a Server Failure or a Refused response. The last +one is the TTL that will be used when a stale cache entry is returned. The `setStaleCacheEntriesTTL(n)` directive can be used to allow `dnsdist` to use expired entries from the cache when no backend is available. Only entries that have @@ -1428,7 +1428,7 @@ instantiate a server with additional parameters * `expunge(n)`: remove entries from the cache, leaving at most `n` entries * `expungeByName(DNSName [, qtype=ANY])`: remove entries matching the supplied DNSName and type from the cache * `isFull()`: return true if the cache has reached the maximum number of entries - * `newPacketCache(maxEntries[, maxTTL=86400, minTTL=0, servFailTTL=60, staleTTL=60])`: return a new PacketCache + * `newPacketCache(maxEntries[, maxTTL=86400, minTTL=0, temporaryFailureTTL=60, staleTTL=60])`: return a new PacketCache * `printStats()`: print the cache stats (hits, misses, deferred lookups and deferred inserts) * `purgeExpired(n)`: remove expired entries from the cache until there is at most `n` entries remaining in the cache * `toString()`: return the number of entries in the Packet Cache, and the maximum number of entries diff --git a/pdns/dnsdist-cache.cc b/pdns/dnsdist-cache.cc index 063aa61141..e36fa5d451 100644 --- a/pdns/dnsdist-cache.cc +++ b/pdns/dnsdist-cache.cc @@ -24,7 +24,7 @@ #include "dnsparser.hh" #include "dnsdist-cache.hh" -DNSDistPacketCache::DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL, uint32_t minTTL, uint32_t servFailTTL, uint32_t staleTTL): d_maxEntries(maxEntries), d_maxTTL(maxTTL), d_servFailTTL(servFailTTL), d_minTTL(minTTL), d_staleTTL(staleTTL) +DNSDistPacketCache::DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL, uint32_t minTTL, uint32_t tempFailureTTL, uint32_t staleTTL): d_maxEntries(maxEntries), d_maxTTL(maxTTL), d_tempFailureTTL(tempFailureTTL), d_minTTL(minTTL), d_staleTTL(staleTTL) { pthread_rwlock_init(&d_lock, 0); /* we reserve maxEntries + 1 to avoid rehashing from occuring @@ -44,15 +44,15 @@ bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, const return true; } -void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, bool servFail) +void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode) { if (responseLen < sizeof(dnsheader)) return; uint32_t minTTL; - if (servFail) { - minTTL = d_servFailTTL; + if (rcode == RCode::ServFail || rcode == RCode::Refused) { + minTTL = d_tempFailureTTL; } else { minTTL = getMinTTL(response, responseLen); diff --git a/pdns/dnsdist-cache.hh b/pdns/dnsdist-cache.hh index f8b3c1cc9e..10164c7c64 100644 --- a/pdns/dnsdist-cache.hh +++ b/pdns/dnsdist-cache.hh @@ -30,10 +30,10 @@ struct DNSQuestion; class DNSDistPacketCache : boost::noncopyable { public: - DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL=86400, uint32_t minTTL=0, uint32_t servFailTTL=60, uint32_t staleTTL=60); + DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL=86400, uint32_t minTTL=0, uint32_t tempFailureTTL=60, uint32_t staleTTL=60); ~DNSDistPacketCache(); - void insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, bool servFail=false); + void insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode); bool get(const DNSQuestion& dq, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, uint32_t allowExpired=0, bool skipAging=false); void purgeExpired(size_t upTo=0); void expunge(size_t upTo=0); @@ -82,7 +82,7 @@ private: std::atomic d_ttlTooShorts{0}; size_t d_maxEntries; uint32_t d_maxTTL; - uint32_t d_servFailTTL; + uint32_t d_tempFailureTTL; uint32_t d_minTTL; uint32_t d_staleTTL; }; diff --git a/pdns/dnsdist-lua2.cc b/pdns/dnsdist-lua2.cc index a5e27495e0..416d4554ce 100644 --- a/pdns/dnsdist-lua2.cc +++ b/pdns/dnsdist-lua2.cc @@ -666,8 +666,8 @@ void moreLua(bool client) } }); - g_lua.writeFunction("newPacketCache", [client](size_t maxEntries, boost::optional maxTTL, boost::optional minTTL, boost::optional servFailTTL, boost::optional staleTTL) { - return std::make_shared(maxEntries, maxTTL ? *maxTTL : 86400, minTTL ? *minTTL : 0, servFailTTL ? *servFailTTL : 60, staleTTL ? *staleTTL : 60); + g_lua.writeFunction("newPacketCache", [client](size_t maxEntries, boost::optional maxTTL, boost::optional minTTL, boost::optional tempFailTTL, boost::optional staleTTL) { + return std::make_shared(maxEntries, maxTTL ? *maxTTL : 86400, minTTL ? *minTTL : 0, tempFailTTL ? *tempFailTTL : 60, staleTTL ? *staleTTL : 60); }); g_lua.registerFunction("toString", &DNSDistPacketCache::toString); g_lua.registerFunction("isFull", &DNSDistPacketCache::isFull); diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index 69ea938785..5358c49629 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -497,7 +497,7 @@ void* tcpClientThread(int pipefd) } if (packetCache && !dq.skipCache) { - packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen, true, dh->rcode == RCode::ServFail); + packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen, true, dh->rcode); } #ifdef HAVE_DNSCRYPT diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 9733359b93..a588e2c9a6 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -439,7 +439,7 @@ try { } if (ids->packetCache && !ids->skipCache) { - ids->packetCache->insert(ids->cacheKey, ids->qname, ids->qtype, ids->qclass, response, responseLen, false, dh->rcode == RCode::ServFail); + ids->packetCache->insert(ids->cacheKey, ids->qname, ids->qtype, ids->qclass, response, responseLen, false, dh->rcode); } #ifdef HAVE_DNSCRYPT diff --git a/pdns/test-dnsdistpacketcache_cc.cc b/pdns/test-dnsdistpacketcache_cc.cc index 825dc4a812..712ae5b2fc 100644 --- a/pdns/test-dnsdistpacketcache_cc.cc +++ b/pdns/test-dnsdistpacketcache_cc.cc @@ -45,7 +45,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) { bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key); BOOST_CHECK_EQUAL(found, false); - PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false); + PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, 0); found = PC.get(dq, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, 0, true); if (found == true) { @@ -135,7 +135,7 @@ static void *threadMangler(void* a) DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, (struct dnsheader*) query.data(), query.size(), query.size(), false); PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key); - PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false); + PC.insert(key, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, 0); } } catch(PDNSException& e) { diff --git a/regression-tests.dnsdist/test_Caching.py b/regression-tests.dnsdist/test_Caching.py index a016b6cc31..46d3329ae8 100644 --- a/regression-tests.dnsdist/test_Caching.py +++ b/regression-tests.dnsdist/test_Caching.py @@ -936,3 +936,139 @@ class TestCachingLongTTL(DNSDistTest): total += self._responsesCounter[key] self.assertEquals(total, misses) + +class TestCachingFailureTTL(DNSDistTest): + + _failureCacheTTL = 2 + _config_params = ['_failureCacheTTL', '_testServerPort'] + _config_template = """ + pc = newPacketCache(1000, 86400, 0, %d, 60) + getPool(""):setCache(pc) + newServer{address="127.0.0.1:%s"} + """ + def testCacheServFailTTL(self): + """ + Cache: ServFail TTL + + """ + misses = 0 + name = 'servfail.failure.cache.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + response.set_rcode(dns.rcode.SERVFAIL) + + # Miss + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + misses += 1 + + # next queries should hit the cache + (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False) + self.assertEquals(receivedResponse, response) + + time.sleep(self._failureCacheTTL + 1) + + # we should not have cached for longer than failure cache + # so it should be a miss + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + misses += 1 + + total = 0 + for key in self._responsesCounter: + total += self._responsesCounter[key] + + self.assertEquals(total, misses) + + def testCacheRefusedTTL(self): + """ + Cache: Refused TTL + + """ + misses = 0 + name = 'refused.failure.cache.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + response.set_rcode(dns.rcode.REFUSED) + + # Miss + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + misses += 1 + + # next queries should hit the cache + (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False) + self.assertEquals(receivedResponse, response) + + time.sleep(self._failureCacheTTL + 1) + + # we should not have cached for longer than failure cache + # so it should be a miss + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + misses += 1 + + total = 0 + for key in self._responsesCounter: + total += self._responsesCounter[key] + + self.assertEquals(total, misses) + + def testCacheHeaderOnlyRefusedTTL(self): + """ + Cache: Header-Only Refused TTL + + """ + misses = 0 + name = 'header-only-refused.failure.cache.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + response.set_rcode(dns.rcode.REFUSED) + response.question = [] + + # Miss + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + misses += 1 + + # next queries should hit the cache + (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False) + self.assertEquals(receivedResponse, response) + + time.sleep(self._failureCacheTTL + 1) + + # we should not have cached for longer than failure cache + # so it should be a miss + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + misses += 1 + + total = 0 + for key in self._responsesCounter: + total += self._responsesCounter[key] + + self.assertEquals(total, misses)