]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Split DNSCrypt encryption from sendResponse. Fix flags. 3582/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 16 Mar 2016 11:15:09 +0000 (12:15 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 16 Mar 2016 11:15:09 +0000 (12:15 +0100)
Splitting DNSCrypt encryption from the sendResponse functions to
avoid the ugly #ifdef'ed definitions.
Flags were not correctly restored for self-generated responses.

pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
regression-tests.dnsdist/test_Advanced.py

index 6c742b4e959584cdbe3d10ac438c003cf92bf917..b46da6ba01c1e77771f7e6935e81147162f8df90 100644 (file)
@@ -133,26 +133,8 @@ catch(...) {
   return false;
 }
 
-static bool sendResponseToClient(int fd, char* response, uint16_t responseLen, size_t responseSize
-#ifdef HAVE_DNSCRYPT
-                                 , DnsCryptContext* dnscryptCtx,
-                                 std::shared_ptr<DnsCryptQuery> dnsCryptQuery
-#endif
-  )
+static bool sendResponseToClient(int fd, const char* response, uint16_t responseLen)
 {
-#ifdef HAVE_DNSCRYPT
-  if (dnscryptCtx && dnsCryptQuery) {
-    uint16_t encryptedResponseLen = 0;
-    int res = dnscryptCtx->encryptResponse(response, responseLen, responseSize, dnsCryptQuery, true, &encryptedResponseLen);
-    if (res == 0) {
-      responseLen = encryptedResponseLen;
-    } else {
-      /* dropping response */
-      vinfolog("Error encrypting the response, dropping.");
-      return false;
-    }
-  }
-#endif
   if (!putNonBlockingMsgLen(fd, responseLen, g_tcpSendTimeout))
     return false;
 
@@ -241,7 +223,7 @@ void* tcpClientThread(int pipefd)
 
           if (!decrypted) {
             if (response.size() > 0) {
-              sendResponseToClient(ci.fd, reinterpret_cast<char*>(response.data()), response.size(), response.size(), nullptr, nullptr);
+              sendResponseToClient(ci.fd, reinterpret_cast<char*>(response.data()), response.size());
             }
             break;
           }
@@ -281,11 +263,13 @@ void* tcpClientThread(int pipefd)
        }
 
        if(dq.dh->qr) { // something turned it into a response
-          sendResponseToClient(ci.fd, queryBuffer, dq.len, dq.size
+          restoreFlags(dh, origFlags);
 #ifdef HAVE_DNSCRYPT
-                               , ci.cs->dnscryptCtx, dnsCryptQuery
+          if (!encryptResponse(queryBuffer, &dq.len, dq.size, true, dnsCryptQuery)) {
+            goto drop;
+          }
 #endif
-            );
+          sendResponseToClient(ci.fd, query, dq.len);
          g_stats.selfAnswered++;
          goto drop;
        }
@@ -319,11 +303,12 @@ void* tcpClientThread(int pipefd)
           uint16_t cachedResponseSize = sizeof cachedResponse;
           uint32_t allowExpired = ds ? 0 : g_staleCacheEntriesTTL;
           if (packetCache->get(dq, consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey, allowExpired)) {
-            sendResponseToClient(ci.fd, cachedResponse, cachedResponseSize, sizeof cachedResponse
 #ifdef HAVE_DNSCRYPT
-                                 , ci.cs->dnscryptCtx, dnsCryptQuery
+            if (!encryptResponse(cachedResponse, &cachedResponseSize, sizeof cachedResponse, true, dnsCryptQuery)) {
+              goto drop;
+            }
 #endif
-              );
+            sendResponseToClient(ci.fd, cachedResponse, cachedResponseSize);
             g_stats.cacheHits++;
             goto drop;
           }
@@ -394,11 +379,13 @@ void* tcpClientThread(int pipefd)
         }
 
         size_t responseSize = rlen;
+        uint16_t addRoom = 0;
 #ifdef HAVE_DNSCRYPT
-        if (ci.cs->dnscryptCtx && (UINT16_MAX - DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) > rlen) {
-          responseSize += DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
+        if (dnsCryptQuery && (UINT16_MAX - rlen) > (uint16_t) DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) {
+          addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
         }
 #endif
+        responseSize += addRoom;
         char answerbuffer[responseSize];
         readn2WithTimeout(dsock, answerbuffer, rlen, ds->tcpRecvTimeout);
         char* response = answerbuffer;
@@ -414,11 +401,7 @@ void* tcpClientThread(int pipefd)
           break;
         }
 
-        if (!fixUpResponse(&response, &responseLen, &responseSize, qname, origFlags, ednsAdded,
-#ifdef HAVE_DNSCRYPT
-                           dnsCryptQuery,
-#endif
-                           rewrittenResponse)) {
+        if (!fixUpResponse(&response, &responseLen, &responseSize, qname, origFlags, ednsAdded, rewrittenResponse, addRoom)) {
           break;
         }
 
@@ -426,11 +409,12 @@ void* tcpClientThread(int pipefd)
          packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen, true, dh->rcode == RCode::ServFail);
        }
 
-        if (!sendResponseToClient(ci.fd, response, responseLen, responseSize
 #ifdef HAVE_DNSCRYPT
-                                 , ci.cs->dnscryptCtx, dnsCryptQuery
+        if (!encryptResponse(response, &responseLen, responseSize, true, dnsCryptQuery)) {
+          goto drop;
+        }
 #endif
-              )) {
+        if (!sendResponseToClient(ci.fd, response, responseLen)) {
           break;
         }
 
index c510feadf2551b2d371e6cd44aa5a7a86f433d25..2537f8d5b825d4afeeb9bc20f1d59215f2d21d11 100644 (file)
@@ -187,15 +187,22 @@ bool responseContentMatches(const char* response, const uint16_t responseLen, co
   return true;
 }
 
-bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded,
-#ifdef HAVE_DNSCRYPT
-                   std::shared_ptr<DnsCryptQuery> dnsCryptQuery,
-#endif
-                   std::vector<uint8_t>& rewrittenResponse)
+void restoreFlags(struct dnsheader* dh, uint16_t origFlags)
 {
   static const uint16_t rdMask = 1 << FLAGS_RD_OFFSET;
   static const uint16_t cdMask = 1 << FLAGS_CD_OFFSET;
   static const uint16_t restoreFlagsMask = UINT16_MAX & ~(rdMask | cdMask);
+  uint16_t * flags = getFlagsFromDNSHeader(dh);
+  /* clear the flags we are about to restore */
+  *flags &= restoreFlagsMask;
+  /* only keep the flags we want to restore */
+  origFlags &= ~restoreFlagsMask;
+  /* set the saved flags as they were */
+  *flags |= origFlags;
+}
+
+bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded, std::vector<uint8_t>& rewrittenResponse, uint16_t addRoom)
+{
   struct dnsheader* dh = (struct dnsheader*) *response;
 
   if (*responseLen < sizeof(dnsheader)) {
@@ -209,13 +216,7 @@ bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize,
     }
   }
 
-  uint16_t * flags = getFlagsFromDNSHeader(dh);
-  /* clear the flags we are about to restore */
-  *flags &= restoreFlagsMask;
-  /* only keep the flags we want to restore */
-  origFlags &= ~restoreFlagsMask;
-  /* set the saved flags as they were */
-  *flags |= origFlags;
+  restoreFlags(dh, origFlags);
 
   if (ednsAdded) {
     const char * optStart = NULL;
@@ -236,12 +237,10 @@ bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize,
         /* Removing an intermediary RR could lead to compression error */
         if (rewriteResponseWithoutEDNS(*response, *responseLen, rewrittenResponse) == 0) {
           *responseLen = rewrittenResponse.size();
-#ifdef HAVE_DNSCRYPT
-          if (dnsCryptQuery && (UINT16_MAX - DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) > *responseLen) {
-            rewrittenResponse.reserve(*responseLen + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE);
+          if (addRoom && (UINT16_MAX - *responseLen) > addRoom) {
+            rewrittenResponse.reserve(*responseLen + addRoom);
           }
           *responseSize = rewrittenResponse.capacity();
-#endif
           *response = reinterpret_cast<char*>(rewrittenResponse.data());
         }
         else {
@@ -254,27 +253,26 @@ bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize,
   return true;
 }
 
-static bool sendUDPResponse(int origFD, char* response, uint16_t responseLen, size_t responseSize,
 #ifdef HAVE_DNSCRYPT
-                            std::shared_ptr<DnsCryptQuery> dnsCryptQuery,
-#endif
-                            int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
+bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr<DnsCryptQuery> dnsCryptQuery)
 {
-#ifdef HAVE_DNSCRYPT
-  uint16_t encryptedResponseLen = 0;
-  if(dnsCryptQuery) {
-    int res = dnsCryptQuery->ctx->encryptResponse(response, responseLen, responseSize, dnsCryptQuery, false, &encryptedResponseLen);
-
+  if (dnsCryptQuery) {
+    uint16_t encryptedResponseLen = 0;
+    int res = dnsCryptQuery->ctx->encryptResponse(response, *responseLen, responseSize, dnsCryptQuery, tcp, &encryptedResponseLen);
     if (res == 0) {
-      responseLen = encryptedResponseLen;
+      *responseLen = encryptedResponseLen;
     } else {
       /* dropping response */
       vinfolog("Error encrypting the response, dropping.");
       return false;
     }
   }
+  return true;
+}
 #endif
 
+static bool sendUDPResponse(int origFD, char* response, uint16_t responseLen, int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
+{
   if(delayMsec && g_delay) {
     DelayedPacket dp{origFD, string(response,responseLen), origRemote, origDest};
     g_delay->submit(dp, delayMsec);
@@ -339,11 +337,13 @@ void* responderThread(std::shared_ptr<DownstreamState> state)
 
     dh->id = ids->origID;
 
-    if (!fixUpResponse(&response, &responseLen, &responseSize, ids->qname, ids->origFlags, ids->ednsAdded,
+    uint16_t addRoom = 0;
 #ifdef HAVE_DNSCRYPT
-                       ids->dnsCryptQuery,
+    if (ids->dnsCryptQuery) {
+      addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
+    }
 #endif
-                       rewrittenResponse)) {
+    if (!fixUpResponse(&response, &responseLen, &responseSize, ids->qname, ids->origFlags, ids->ednsAdded, rewrittenResponse, addRoom)) {
       continue;
     }
 
@@ -351,11 +351,12 @@ void* responderThread(std::shared_ptr<DownstreamState> state)
       ids->packetCache->insert(ids->cacheKey, ids->qname, ids->qtype, ids->qclass, response, responseLen, false, dh->rcode == RCode::ServFail);
     }
 
-    sendUDPResponse(origFD, response, responseLen, responseSize,
 #ifdef HAVE_DNSCRYPT
-                    ids->dnsCryptQuery,
+    if (!encryptResponse(response, &responseLen, responseSize, false, ids->dnsCryptQuery)) {
+      continue;
+    }
 #endif
-                    ids->delayMsec, ids->origDest, ids->origRemote);
+    sendUDPResponse(origFD, response, responseLen, ids->delayMsec, ids->origDest, ids->origRemote);
 
     g_stats.responses++;
 
@@ -808,7 +809,7 @@ try
             if(!HarvestDestinationAddress(&msgh, &dest)) {
               dest.sin4.sin_family = 0;
             }
-            sendUDPResponse(cs->udpFD, reinterpret_cast<char*>(response.data()), response.size(), response.size(), nullptr, 0, dest, remote);
+            sendUDPResponse(cs->udpFD, reinterpret_cast<char*>(response.data()), response.size(), 0, dest, remote);
           }
           continue;
         }
@@ -852,18 +853,20 @@ try
       if(dq.dh->qr) { // something turned it into a response
         char* response = query;
         uint16_t responseLen = dq.len;
-        uint16_t responseSize = dq.size;
         g_stats.selfAnswered++;
 
+        restoreFlags(dh, origFlags);
+
         ComboAddress dest;
         if(!HarvestDestinationAddress(&msgh, &dest)) {
           dest.sin4.sin_family = 0;
         }
-        sendUDPResponse(cs->udpFD, response, responseLen, responseSize,
 #ifdef HAVE_DNSCRYPT
-                        dnsCryptQuery,
+        if (!encryptResponse(response, &responseLen, dq.size, false, dnsCryptQuery)) {
+          continue;
+        }
 #endif
-                        0, dest, remote);
+        sendUDPResponse(cs->udpFD, response, responseLen, 0, dest, remote);
         continue;
       }
 
@@ -892,11 +895,12 @@ try
           if(!HarvestDestinationAddress(&msgh, &dest)) {
             dest.sin4.sin_family = 0;
           }
-          sendUDPResponse(cs->udpFD, cachedResponse, cachedResponseSize, sizeof cachedResponse,
 #ifdef HAVE_DNSCRYPT
-                          dnsCryptQuery,
+          if (!encryptResponse(cachedResponse, &cachedResponseSize, sizeof cachedResponse, false, dnsCryptQuery)) {
+            continue;
+          }
 #endif
-                          0, dest, remote);
+          sendUDPResponse(cs->udpFD, cachedResponse, cachedResponseSize, 0, dest, remote);
           g_stats.cacheHits++;
           g_stats.latency0_1++;  // we're not going to measure this
           doLatencyAverages(0);  // same
index fe43e83cf1304b49e8f48c1a9da13dbcf53e6974..bea724d919c6d9a442180aa65dc6d23603950cb6 100644 (file)
@@ -509,14 +509,12 @@ void resetLuaSideEffect(); // reset to indeterminate state
 
 bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote);
 bool processQuery(LocalStateHolder<NetmaskTree<DynBlock> >& localDynBlock, LocalStateHolder<vector<pair<std::shared_ptr<DNSRule>, std::shared_ptr<DNSAction> > > >& localRulactions, blockfilter_t blockFilter, DNSQuestion& dq, const ComboAddress& remote, string& poolname, int* delayMsec, const struct timespec& now);
-bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded,
-#ifdef HAVE_DNSCRYPT
-                   std::shared_ptr<DnsCryptQuery> dnsCryptQuery,
-#endif
-                   std::vector<uint8_t>& rewrittenResponse);
+bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded, std::vector<uint8_t>& rewrittenResponse, uint16_t addRoom);
+void restoreFlags(struct dnsheader* dh, uint16_t origFlags);
 
 #ifdef HAVE_DNSCRYPT
 extern std::vector<std::tuple<ComboAddress,DnsCryptContext,bool>> g_dnsCryptLocals;
 
 int handleDnsCryptQuery(DnsCryptContext* ctx, char* packet, uint16_t len, std::shared_ptr<DnsCryptQuery>& query, uint16_t* decryptedQueryLen, bool tcp, std::vector<uint8_t>& reponse);
+bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr<DnsCryptQuery> dnsCryptQuery);
 #endif
index 03ddb39fd03414ad9b6533ead6018f2d588837ff..24e66d2543a4b2828679632538297d067d755224 100644 (file)
@@ -835,3 +835,41 @@ class TestAdvancedStringOnlyServer(DNSDistTest):
         receivedQuery.id = query.id
         self.assertEquals(query, receivedQuery)
         self.assertEquals(response, receivedResponse)
+
+class TestAdvancedRestoreFlagsOnSelfResponse(DNSDistTest):
+
+    _config_template = """
+    addAction(AllRule(), DisableValidationAction())
+    addAction(AllRule(), SpoofAction("192.0.2.1"))
+    newServer{address="127.0.0.1:%s"}
+    """
+
+    def testAdvancedRestoreFlagsOnSpoofResponse(self):
+        """
+        Advanced: Restore flags on spoofed response
+
+        Send a query with CD flag cleared, dnsdist is
+        instructed to set it, then to spoof the response,
+        check that response has the flag cleared.
+        """
+        name = 'spoofed.restoreflags.advanced.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        expectedQuery = dns.message.make_query(name, 'A', 'IN')
+
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.assertTrue(receivedResponse)
+        self.assertEquals(response, receivedResponse)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.assertTrue(receivedResponse)
+        self.assertEquals(response, receivedResponse)