From f87c4aff63f32f84d8291a04bdeb887c9562613f Mon Sep 17 00:00:00 2001 From: Remi Gacogne Date: Fri, 4 Mar 2016 18:12:32 +0100 Subject: [PATCH] dnsdist: Check response validity over TCP, more cache fixes - Add `unsetCache()` to remove the cache from a pool - Check the response size before caching it, and make no assumption when getting it from the cache - Check that the response is larger than sizeof(dnsheader) over TCP too - Check that the response matches the query over TCP too, because we reuse downstream connections --- pdns/README-dnsdist.md | 7 +++++++ pdns/dnsdist-cache.cc | 13 ++++++++++--- pdns/dnsdist-lua2.cc | 3 +++ pdns/dnsdist-tcp.cc | 25 ++++++++++++++++++++++++- pdns/dnsdist.cc | 4 +++- regression-tests.dnsdist/test_Basics.py | 6 ++++++ 6 files changed, 53 insertions(+), 5 deletions(-) diff --git a/pdns/README-dnsdist.md b/pdns/README-dnsdist.md index e2057e8dc4..c84799f5af 100644 --- a/pdns/README-dnsdist.md +++ b/pdns/README-dnsdist.md @@ -734,6 +734,12 @@ A reference to the cache affected to a specific pool can be retrieved with: getPool("poolname"):getCache() ``` +And removed with: + +``` +getPool("poolname"):unsetCache() +``` + Cache usage stats (hits, misses, deferred inserts and lookups, collisions) can be displayed by using the `printStats()` method: @@ -1073,6 +1079,7 @@ instantiate a server with additional parameters * ServerPool related: * `getCache()`: return the current packet cache, if any * `setCache(PacketCache)`: set the cache for this pool + * `unsetCache()`: remove the packet cache from this pool * PacketCache related: * `expunge(n)`: remove entries from the cache, leaving at most `n` entries * `expungeByName(DNSName [, qtype=ANY])`: remove entries matching the supplied DNSName and type from the cache diff --git a/pdns/dnsdist-cache.cc b/pdns/dnsdist-cache.cc index ebfc10d9cd..c8a1dd175c 100644 --- a/pdns/dnsdist-cache.cc +++ b/pdns/dnsdist-cache.cc @@ -25,7 +25,7 @@ bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, const void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, bool servFail) { - if (responseLen == 0) + if (responseLen < sizeof(dnsheader)) return; uint32_t minTTL; @@ -144,10 +144,17 @@ bool DNSDistPacketCache::get(const DNSQuestion& dq, uint16_t consumed, uint16_t } string dnsQName(dq.qname->toDNSString()); + const size_t dnsQNameLen = dnsQName.length(); + if (value.len < (sizeof(dnsheader) + dnsQNameLen)) { + return false; + } + memcpy(response, &queryId, sizeof(queryId)); memcpy(response + sizeof(queryId), value.value.c_str() + sizeof(queryId), sizeof(dnsheader) - sizeof(queryId)); - memcpy(response + sizeof(dnsheader), dnsQName.c_str(), dnsQName.length()); - memcpy(response + sizeof(dnsheader) + dnsQName.length(), value.value.c_str() + sizeof(dnsheader) + dnsQName.length(), value.value.length() - (sizeof(dnsheader) + dnsQName.length())); + memcpy(response + sizeof(dnsheader), dnsQName.c_str(), dnsQNameLen); + if (value.len > (sizeof(dnsheader) + dnsQNameLen)) { + memcpy(response + sizeof(dnsheader) + dnsQNameLen, value.value.c_str() + sizeof(dnsheader) + dnsQNameLen, value.len - (sizeof(dnsheader) + dnsQNameLen)); + } *responseLen = value.len; if (!stale) { age = now - value.added; diff --git a/pdns/dnsdist-lua2.cc b/pdns/dnsdist-lua2.cc index 72ed238ab1..6bd10be2c4 100644 --- a/pdns/dnsdist-lua2.cc +++ b/pdns/dnsdist-lua2.cc @@ -530,6 +530,9 @@ void moreLua(bool client) pool->packetCache = cache; }); g_lua.registerFunction("getCache", &ServerPool::getCache); + g_lua.registerFunction::*)()>("unsetCache", [](std::shared_ptr pool) { + pool->packetCache = nullptr; + }); g_lua.writeFunction("newPacketCache", [client](size_t maxEntries, boost::optional maxTTL, boost::optional minTTL, boost::optional servFailTTL, boost::optional staleTTL) { return std::make_shared(maxEntries, maxTTL ? *maxTTL : 86400, minTTL ? *minTTL : 60, servFailTTL ? *servFailTTL : 60, staleTTL ? *staleTTL : 60); diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index f4ed14d741..520b04954a 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -440,6 +440,27 @@ void* tcpClientThread(int pipefd) --ds->outstanding; outstanding = false; + if (rlen < sizeof(dnsheader)) { + break; + } + + dh = (struct dnsheader*) response; + DNSName rqname; + uint16_t rqtype, rqclass; + try { + rqname=DNSName(response, responseLen, sizeof(dnsheader), false, &rqtype, &rqclass, &consumed); + } + catch(std::exception& e) { + if(rlen > (ssize_t)sizeof(dnsheader)) + infolog("Backend %s sent us a response with id %d that did not parse: %s", ds->remote.toStringWithPort(), ntohs(dh->id), e.what()); + g_stats.nonCompliantResponses++; + break; + } + + if (rqtype != qtype || rqclass != qclass || rqname != qname) { + break; + } + if (ednsAdded) { const char * optStart = NULL; size_t optLen = 0; @@ -477,7 +498,9 @@ void* tcpClientThread(int pipefd) if(g_fixupCase) { string realname = qname.toDNSString(); - memcpy(response + sizeof(dnsheader), realname.c_str(), realname.length()); + if (responseLen >= (sizeof(dnsheader) + realname.length())) { + memcpy(response + sizeof(dnsheader), realname.c_str(), realname.length()); + } } if (packetCache && !dq.skipCache) { diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 7e86138448..688626928b 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -220,7 +220,9 @@ void* responderThread(std::shared_ptr state) if(g_fixupCase) { string realname = ids->qname.toDNSString(); - memcpy(packet+12, realname.c_str(), realname.length()); + if (responseLen >= (sizeof(dnsheader) + realname.length())) { + memcpy(packet+12, realname.c_str(), realname.length()); + } } if(dh->tc && g_truncateTC) { diff --git a/regression-tests.dnsdist/test_Basics.py b/regression-tests.dnsdist/test_Basics.py index e6cbf6b0b4..5dcc6fb79c 100644 --- a/regression-tests.dnsdist/test_Basics.py +++ b/regression-tests.dnsdist/test_Basics.py @@ -309,6 +309,12 @@ class TestBasics(DNSDistTest): receivedQuery.id = query.id self.assertEquals(query, receivedQuery) + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, unrelatedResponse) + self.assertTrue(receivedQuery) + self.assertEquals(receivedResponse, None) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + if __name__ == '__main__': unittest.main() -- 2.47.2