]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Refactor duplicated query handling code (UDP/TCP)
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 15 Mar 2016 16:57:19 +0000 (17:57 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 15 Mar 2016 16:57:19 +0000 (17:57 +0100)
Merge the UDP/TCP code handling:

* dynamic blocks
* blockfilter
* rules

This fixes an issue with DNSCrypt, where cached responses were not
being encrypted.

pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
regression-tests.dnsdist/test_DNSCrypt.py

index 1df6b06d24ca97c3543160e3a8b7b49cf5b768f0..6c742b4e959584cdbe3d10ac438c003cf92bf917 100644 (file)
@@ -167,7 +167,6 @@ void* tcpClientThread(int pipefd)
   /* we get launched with a pipe on which we receive file descriptors from clients that we own
      from that point on */
      
-  typedef std::function<bool(const DNSQuestion*)> blockfilter_t;
   bool outstanding = false;
   blockfilter_t blockFilter = 0;
   
@@ -215,6 +214,9 @@ void* tcpClientThread(int pipefd)
         if(!getNonBlockingMsgLen(ci.fd, &qlen, g_tcpRecvTimeout))
           break;
 
+        ci.cs->queries++;
+        g_stats.queries++;
+
         if (qlen < sizeof(dnsheader)) {
           g_stats.nonCompliantQueries++;
           break;
@@ -227,6 +229,7 @@ void* tcpClientThread(int pipefd)
         char queryBuffer[querySize];
         const char* query = queryBuffer;
         readn2WithTimeout(ci.fd, queryBuffer, qlen, g_tcpRecvTimeout);
+
 #ifdef HAVE_DNSCRYPT
         std::shared_ptr<DnsCryptQuery> dnsCryptQuery = 0;
 
@@ -257,92 +260,24 @@ void* tcpClientThread(int pipefd)
           goto drop;
         }
 
+        if (dh->rd) {
+          g_stats.rdQueries++;
+        }
+
+       const uint16_t* flags = getFlagsFromDNSHeader(dh);
+       uint16_t origFlags = *flags;
        uint16_t qtype, qclass;
        unsigned int consumed = 0;
        DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
        DNSQuestion dq(&qname, qtype, qclass, &ci.cs->local, &ci.remote, (dnsheader*)query, querySize, qlen, true);
-       string ruleresult;
-       const uint16_t * flags = getFlagsFromDNSHeader(dq.dh);
-       uint16_t origFlags = *flags;
+
+       string poolname;
+       int delayMsec=0;
        struct timespec now;
        clock_gettime(CLOCK_MONOTONIC, &now);
 
-       {
-         WriteLock wl(&g_rings.queryLock);
-         g_rings.queryRing.push_back({now,ci.remote,qname,dq.len,dq.qtype,*dq.dh});
-       }
-
-       g_stats.queries++;
-       if (ci.cs) {
-         ci.cs->queries++;
-       }
-
-       if(auto got=localDynBlockNMG->lookup(ci.remote)) {
-         if(now < got->second.until) {
-           vinfolog("Query from %s dropped because of dynamic block", ci.remote.toStringWithPort());
-           g_stats.dynBlocked++;
-           got->second.blocks++;
-           goto drop;
-         }
-       }
-
-        if (dq.dh->rd) {
-          g_stats.rdQueries++;
-        }
-
-        if(blockFilter) {
-         std::lock_guard<std::mutex> lock(g_luamutex);
-       
-         if(blockFilter(&dq)) {
-           g_stats.blockFilter++;
-           goto drop;
-          }
-          if(dq.dh->tc && dq.dh->qr) { // don't truncate on TCP/IP!
-            dq.dh->tc=false;        // maybe we should just pass blockFilter the TCP status
-            dq.dh->qr=false;
-          }
-        }
-
-        bool done=false;
-       DNSAction::Action action=DNSAction::Action::None;
-       for(const auto& lr : *localRulactions) {
-         if(lr.first->matches(&dq)) {
-            lr.first->d_matches++;
-            action=(*lr.second)(&dq, &ruleresult);
-            switch(action) {
-            case DNSAction::Action::Allow:
-              done = true;
-              break;
-            case DNSAction::Action::Drop:
-              g_stats.ruleDrop++;
-              goto drop;
-              break;
-            case DNSAction::Action::Nxdomain:
-              dq.dh->rcode = RCode::NXDomain;
-              dq.dh->qr=true;
-              g_stats.ruleNXDomain++;
-              done = true;
-              break;
-            case DNSAction::Action::Spoof:
-              spoofResponseFromString(dq, ruleresult);
-              done = true;
-              break;
-            case DNSAction::Action::HeaderModify:
-              done = true;
-              break;
-            case DNSAction::Action::Pool:
-              poolname=ruleresult;
-              done = true;
-              break;
-            /* non-terminal actions follow */
-            case DNSAction::Action::Delay:
-            case DNSAction::Action::None:
-              break;
-            }
-           if(done) {
-             break;
-           }
-         }
+       if (!processQuery(localDynBlockNMG, localRulactions, blockFilter, dq, ci.remote, poolname, &delayMsec, now)) {
+         goto drop;
        }
 
        if(dq.dh->qr) { // something turned it into a response
index cb42bd76b90f6fc7d1f9e679d2639502f0e4334e..c510feadf2551b2d371e6cd44aa5a7a86f433d25 100644 (file)
@@ -334,7 +334,7 @@ void* responderThread(std::shared_ptr<DownstreamState> state)
     --state->outstanding;  // you'd think an attacker could game this, but we're using connected socket
 
     if(dh->tc && g_truncateTC) {
-      truncateTC(response, (uint16_t*) &responseLen);
+      truncateTC(response, &responseLen);
     }
 
     dh->id = ids->origID;
@@ -648,6 +648,76 @@ void spoofResponseFromString(DNSQuestion& dq, const string& spoofContent)
   }
 }
 
+bool processQuery(LocalStateHolder<NetmaskTree<DynBlock> >& localDynBlock, LocalStateHolder<vector<pair<std::shared_ptr<DNSRule>, std::shared_ptr<DNSAction> > > >& localRulactions, blockfilter_t blockFilter, DNSQuestion& dq, const ComboAddress& remote, string& poolname, int* delayMsec, const struct timespec& now)
+{
+  {
+    WriteLock wl(&g_rings.queryLock);
+    g_rings.queryRing.push_back({now,remote,*dq.qname,dq.len,dq.qtype,*dq.dh});
+  }
+
+  if(auto got=localDynBlock->lookup(remote)) {
+    if(now < got->second.until) {
+      vinfolog("Query from %s dropped because of dynamic block", remote.toStringWithPort());
+      g_stats.dynBlocked++;
+      got->second.blocks++;
+      return false;
+    }
+  }
+
+  if(blockFilter) {
+    std::lock_guard<std::mutex> lock(g_luamutex);
+
+    if(blockFilter(&dq)) {
+      g_stats.blockFilter++;
+      return false;
+    }
+  }
+
+  DNSAction::Action action=DNSAction::Action::None;
+  string ruleresult;
+  for(const auto& lr : *localRulactions) {
+    if(lr.first->matches(&dq)) {
+      lr.first->d_matches++;
+      action=(*lr.second)(&dq, &ruleresult);
+
+      switch(action) {
+      case DNSAction::Action::Allow:
+        return true;
+        break;
+      case DNSAction::Action::Drop:
+        g_stats.ruleDrop++;
+        return false;
+        break;
+      case DNSAction::Action::Nxdomain:
+        dq.dh->rcode = RCode::NXDomain;
+        dq.dh->qr=true;
+        g_stats.ruleNXDomain++;
+        return true;
+        break;
+      case DNSAction::Action::Spoof:
+        spoofResponseFromString(dq, ruleresult);
+        return true;
+        break;
+      case DNSAction::Action::HeaderModify:
+        return true;
+        break;
+      case DNSAction::Action::Pool:
+        poolname=ruleresult;
+        return true;
+        break;
+        /* non-terminal actions follow */
+      case DNSAction::Action::Delay:
+        *delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
+        break;
+      case DNSAction::Action::None:
+        break;
+      }
+    }
+  }
+
+  return true;
+}
+
 static ssize_t udpClientSendRequestToBackend(DownstreamState* ss, const int sd, const char* request, const size_t requestLen)
 {
   if (ss->sourceItf == 0) {
@@ -672,7 +742,6 @@ try
   string largerQuery;
   uint16_t qtype, qclass;
 
-  typedef std::function<bool(DNSQuestion*)> blockfilter_t;
   blockfilter_t blockFilter = 0;
   {
     std::lock_guard<std::mutex> lock(g_luamutex);
@@ -703,6 +772,12 @@ try
       ssize_t ret = recvmsg(cs->udpFD, &msgh, 0);
       queryId = 0;
 
+      if(!acl->match(remote)) {
+       vinfolog("Query from %s dropped because of ACL", remote.toStringWithPort());
+       g_stats.aclDrops++;
+       continue;
+      }
+
       cs->queries++;
       g_stats.queries++;
 
@@ -718,12 +793,6 @@ try
         continue;
       }
 
-      if(!acl->match(remote)) {
-       vinfolog("Query from %s dropped because of ACL", remote.toStringWithPort());
-       g_stats.aclDrops++;
-       continue;
-      }
-
       uint16_t len = (uint16_t) ret;
 #ifdef HAVE_DNSCRYPT
       if (cs->dnscryptCtx) {
@@ -739,11 +808,7 @@ try
             if(!HarvestDestinationAddress(&msgh, &dest)) {
               dest.sin4.sin_family = 0;
             }
-            sendUDPResponse(cs->udpFD, reinterpret_cast<char*>(response.data()), response.size(), response.size(),
-#ifdef HAVE_DNSCRYPT
-                            nullptr,
-#endif
-                            0, dest, remote);
+            sendUDPResponse(cs->udpFD, reinterpret_cast<char*>(response.data()), response.size(), response.size(), nullptr, 0, dest, remote);
           }
           continue;
         }
@@ -772,83 +837,15 @@ try
       const uint16_t origFlags = *flags;
       unsigned int consumed = 0;
       DNSName qname(query, len, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
-
       DNSQuestion dq(&qname, qtype, qclass, &cs->local, &remote, dh, sizeof(packet), len, false);
 
-      struct timespec now;
-      clock_gettime(CLOCK_MONOTONIC, &now);
-      {
-        WriteLock wl(&g_rings.queryLock);
-        g_rings.queryRing.push_back({now,remote,qname,dq.len,dq.qtype,*dq.dh});
-      }
-
-      if(auto got=localDynBlock->lookup(remote)) {
-       if(now < got->second.until) {
-         vinfolog("Query from %s dropped because of dynamic block", remote.toStringWithPort());
-         g_stats.dynBlocked++;
-         got->second.blocks++;
-         continue;
-       }
-      }
-
-      if(blockFilter) {
-       std::lock_guard<std::mutex> lock(g_luamutex);
-       
-       if(blockFilter(&dq)) {
-         g_stats.blockFilter++;
-         continue;
-       }
-      }
-
-      DNSAction::Action action=DNSAction::Action::None;
-      string ruleresult;
       string poolname;
       int delayMsec=0;
-      bool done=false;
-      for(const auto& lr : *localRulactions) {
-        if(lr.first->matches(&dq)) {
-          lr.first->d_matches++;
-          action=(*lr.second)(&dq, &ruleresult);
-
-          switch(action) {
-          case DNSAction::Action::Allow:
-            done = true;
-            break;
-          case DNSAction::Action::Drop:
-            g_stats.ruleDrop++;
-            done = true;
-            break;
-          case DNSAction::Action::Nxdomain:
-            dq.dh->rcode = RCode::NXDomain;
-            dq.dh->qr=true;
-            g_stats.ruleNXDomain++;
-            done = true;
-            break;
-          case DNSAction::Action::Spoof:
-            spoofResponseFromString(dq, ruleresult);
-            done = true;
-            break;
-          case DNSAction::Action::HeaderModify:
-            done = true;
-            break;
-          case DNSAction::Action::Pool:
-            poolname=ruleresult;
-            done = true;
-            break;
-          /* non-terminal actions follow */
-          case DNSAction::Action::Delay:
-            delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
-            break;
-          case DNSAction::Action::None:
-           break;
-          }
-          if (done) {
-            break;
-          }
-        }
-      }
+      struct timespec now;
+      clock_gettime(CLOCK_MONOTONIC, &now);
 
-      if (action == DNSAction::Action::Drop) {
+      if (!processQuery(localDynBlock, localRulactions, blockFilter, dq, remote, poolname, &delayMsec, now))
+      {
         continue;
       }
 
index 946ceed3bdc8f25c0110700d031d21de8274e781..fe43e83cf1304b49e8f48c1a9da13dbcf53e6974 100644 (file)
@@ -381,6 +381,7 @@ struct DNSQuestion
   bool skipCache{false};
 };
 
+typedef std::function<bool(const DNSQuestion*)> blockfilter_t;
 template <class T> using NumberedVector = std::vector<std::pair<unsigned int, T> >;
 
 void* responderThread(std::shared_ptr<DownstreamState> state);
@@ -507,6 +508,7 @@ bool getLuaNoSideEffect(); // set if there were only explicit declarations of _n
 void resetLuaSideEffect(); // reset to indeterminate state
 
 bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote);
+bool processQuery(LocalStateHolder<NetmaskTree<DynBlock> >& localDynBlock, LocalStateHolder<vector<pair<std::shared_ptr<DNSRule>, std::shared_ptr<DNSAction> > > >& localRulactions, blockfilter_t blockFilter, DNSQuestion& dq, const ComboAddress& remote, string& poolname, int* delayMsec, const struct timespec& now);
 bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded,
 #ifdef HAVE_DNSCRYPT
                    std::shared_ptr<DnsCryptQuery> dnsCryptQuery,
index abf1da5ca4fd0bc95a18e1c30a997a5fa74b83e8..c20bdcdc324888f56db8ac7919f108126f5763b1 100644 (file)
@@ -128,6 +128,7 @@ class TestDNSCryptWithCache(DNSDistTest):
         """
         DNSCrypt: encrypted A query served from cache
         """
+        misses = 0
         client = dnscrypt.DNSCryptClient(self._providerName, self._providerFingerprint, "127.0.0.1", 8443)
         name = 'cacheda.dnscrypt.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
@@ -152,6 +153,7 @@ class TestDNSCryptWithCache(DNSDistTest):
         receivedQuery.id = query.id
         self.assertEquals(query, receivedQuery)
         self.assertEquals(response, receivedResponse)
+        misses += 1
 
         # second query should get a cached response
         data = client.query(query.to_wire())
@@ -163,3 +165,7 @@ class TestDNSCryptWithCache(DNSDistTest):
         self.assertEquals(receivedQuery, None)
         self.assertTrue(receivedResponse)
         self.assertEquals(response, receivedResponse)
+        total = 0
+        for key in self._responsesCounter:
+            total += self._responsesCounter[key]
+        self.assertEquals(total, misses)