]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Check response validity over TCP, more cache fixes 3509/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 4 Mar 2016 17:12:32 +0000 (18:12 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 4 Mar 2016 17:12:32 +0000 (18:12 +0100)
- Add `unsetCache()` to remove the cache from a pool
- Check the response size before caching it, and make no
assumption when getting it from the cache
- Check that the response is larger than sizeof(dnsheader) over
TCP too
- Check that the response matches the query over TCP too, because
we reuse downstream connections

pdns/README-dnsdist.md
pdns/dnsdist-cache.cc
pdns/dnsdist-lua2.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
regression-tests.dnsdist/test_Basics.py

index e2057e8dc427c64e1fc21df0591f07ca0156e228..c84799f5af7b50f29ef5008f129069a407e38cd7 100644 (file)
@@ -734,6 +734,12 @@ A reference to the cache affected to a specific pool can be retrieved with:
 getPool("poolname"):getCache()
 ```
 
+And removed with:
+
+```
+getPool("poolname"):unsetCache()
+```
+
 Cache usage stats (hits, misses, deferred inserts and lookups, collisions)
 can be displayed by using the `printStats()` method:
 
@@ -1073,6 +1079,7 @@ instantiate a server with additional parameters
  * ServerPool related:
     * `getCache()`: return the current packet cache, if any
     * `setCache(PacketCache)`: set the cache for this pool
+    * `unsetCache()`: remove the packet cache from this pool
  * PacketCache related:
     * `expunge(n)`: remove entries from the cache, leaving at most `n` entries
     * `expungeByName(DNSName [, qtype=ANY])`: remove entries matching the supplied DNSName and type from the cache
index ebfc10d9cd0e36ba2c7edc379609817e50544edc..c8a1dd175ccc66a362f96786e366ae515ba14cee 100644 (file)
@@ -25,7 +25,7 @@ bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, const
 
 void DNSDistPacketCache::insert(uint32_t key, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, bool servFail)
 {
-  if (responseLen == 0)
+  if (responseLen < sizeof(dnsheader))
     return;
 
   uint32_t minTTL;
@@ -144,10 +144,17 @@ bool DNSDistPacketCache::get(const DNSQuestion& dq, uint16_t consumed, uint16_t
     }
 
     string dnsQName(dq.qname->toDNSString());
+    const size_t dnsQNameLen = dnsQName.length();
+    if (value.len < (sizeof(dnsheader) + dnsQNameLen)) {
+      return false;
+    }
+
     memcpy(response, &queryId, sizeof(queryId));
     memcpy(response + sizeof(queryId), value.value.c_str() + sizeof(queryId), sizeof(dnsheader) - sizeof(queryId));
-    memcpy(response + sizeof(dnsheader), dnsQName.c_str(), dnsQName.length());
-    memcpy(response + sizeof(dnsheader) + dnsQName.length(), value.value.c_str() + sizeof(dnsheader) + dnsQName.length(), value.value.length() - (sizeof(dnsheader) + dnsQName.length()));
+    memcpy(response + sizeof(dnsheader), dnsQName.c_str(), dnsQNameLen);
+    if (value.len > (sizeof(dnsheader) + dnsQNameLen)) {
+      memcpy(response + sizeof(dnsheader) + dnsQNameLen, value.value.c_str() + sizeof(dnsheader) + dnsQNameLen, value.len - (sizeof(dnsheader) + dnsQNameLen));
+    }
     *responseLen = value.len;
     if (!stale) {
       age = now - value.added;
index 72ed238ab1192aedea17a73792d70ef0de37f72b..6bd10be2c4e89d5d7c22e9186353ab80215834ca 100644 (file)
@@ -530,6 +530,9 @@ void moreLua(bool client)
         pool->packetCache = cache;
     });
     g_lua.registerFunction("getCache", &ServerPool::getCache);
+    g_lua.registerFunction<void(std::shared_ptr<ServerPool>::*)()>("unsetCache", [](std::shared_ptr<ServerPool> pool) {
+        pool->packetCache = nullptr;
+    });
 
     g_lua.writeFunction("newPacketCache", [client](size_t maxEntries, boost::optional<uint32_t> maxTTL, boost::optional<uint32_t> minTTL, boost::optional<uint32_t> servFailTTL, boost::optional<uint32_t> staleTTL) {
         return std::make_shared<DNSDistPacketCache>(maxEntries, maxTTL ? *maxTTL : 86400, minTTL ? *minTTL : 60, servFailTTL ? *servFailTTL : 60, staleTTL ? *staleTTL : 60);
index f4ed14d7416b13b1f33a94c6698cb708f2567106..520b04954af1a0808ca7935884fb92f15d98f817 100644 (file)
@@ -440,6 +440,27 @@ void* tcpClientThread(int pipefd)
         --ds->outstanding;
         outstanding = false;
 
+        if (rlen < sizeof(dnsheader)) {
+          break;
+        }
+
+        dh = (struct dnsheader*) response;
+        DNSName rqname;
+        uint16_t rqtype, rqclass;
+        try {
+          rqname=DNSName(response, responseLen, sizeof(dnsheader), false, &rqtype, &rqclass, &consumed);
+        }
+        catch(std::exception& e) {
+          if(rlen > (ssize_t)sizeof(dnsheader))
+            infolog("Backend %s sent us a response with id %d that did not parse: %s", ds->remote.toStringWithPort(), ntohs(dh->id), e.what());
+          g_stats.nonCompliantResponses++;
+          break;
+        }
+
+        if (rqtype != qtype || rqclass != qclass || rqname != qname) {
+          break;
+        }
+
         if (ednsAdded) {
           const char * optStart = NULL;
           size_t optLen = 0;
@@ -477,7 +498,9 @@ void* tcpClientThread(int pipefd)
 
        if(g_fixupCase) {
          string realname = qname.toDNSString();
-         memcpy(response + sizeof(dnsheader), realname.c_str(), realname.length());
+         if (responseLen >= (sizeof(dnsheader) + realname.length())) {
+           memcpy(response + sizeof(dnsheader), realname.c_str(), realname.length());
+         }
        }
 
        if (packetCache && !dq.skipCache) {
index 7e86138448e98f0b3962682371df7364e1f05f90..688626928bbd7e78ae1a51d21abc55fb070a1138 100644 (file)
@@ -220,7 +220,9 @@ void* responderThread(std::shared_ptr<DownstreamState> state)
 
     if(g_fixupCase) {
       string realname = ids->qname.toDNSString();
-      memcpy(packet+12, realname.c_str(), realname.length());
+      if (responseLen >= (sizeof(dnsheader) + realname.length())) {
+        memcpy(packet+12, realname.c_str(), realname.length());
+      }
     }
 
     if(dh->tc && g_truncateTC) {
index e6cbf6b0b43f733ad06732f1a780b19602008a5f..5dcc6fb79c097b7da60b0e64561356acd78c48bd 100644 (file)
@@ -309,6 +309,12 @@ class TestBasics(DNSDistTest):
         receivedQuery.id = query.id
         self.assertEquals(query, receivedQuery)
 
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, unrelatedResponse)
+        self.assertTrue(receivedQuery)
+        self.assertEquals(receivedResponse, None)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+
 
 if __name__ == '__main__':
     unittest.main()