]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Refactor duplicated response handling code (UDP/TCP)
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 15 Mar 2016 14:56:37 +0000 (15:56 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 15 Mar 2016 14:56:37 +0000 (15:56 +0100)
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
regression-tests.dnsdist/dnscrypt.py
regression-tests.dnsdist/test_DNSCrypt.py

index d4320e7aabbf7340725881a6640d1e73a0fbf8dc..1df6b06d24ca97c3543160e3a8b7b49cf5b768f0 100644 (file)
@@ -133,6 +133,33 @@ catch(...) {
   return false;
 }
 
+static bool sendResponseToClient(int fd, char* response, uint16_t responseLen, size_t responseSize
+#ifdef HAVE_DNSCRYPT
+                                 , DnsCryptContext* dnscryptCtx,
+                                 std::shared_ptr<DnsCryptQuery> dnsCryptQuery
+#endif
+  )
+{
+#ifdef HAVE_DNSCRYPT
+  if (dnscryptCtx && dnsCryptQuery) {
+    uint16_t encryptedResponseLen = 0;
+    int res = dnscryptCtx->encryptResponse(response, responseLen, responseSize, dnsCryptQuery, true, &encryptedResponseLen);
+    if (res == 0) {
+      responseLen = encryptedResponseLen;
+    } else {
+      /* dropping response */
+      vinfolog("Error encrypting the response, dropping.");
+      return false;
+    }
+  }
+#endif
+  if (!putNonBlockingMsgLen(fd, responseLen, g_tcpSendTimeout))
+    return false;
+
+  writen2WithTimeout(fd, response, responseLen, g_tcpSendTimeout);
+  return true;
+}
+
 std::shared_ptr<TCPClientCollection> g_tcpclientthreads;
 
 void* tcpClientThread(int pipefd)
@@ -173,9 +200,6 @@ void* tcpClientThread(int pipefd)
 
     uint16_t qlen, rlen;
     string poolname;
-    const uint16_t rdMask = 1 << FLAGS_RD_OFFSET;
-    const uint16_t cdMask = 1 << FLAGS_CD_OFFSET;
-    const uint16_t restoreFlagsMask = UINT16_MAX & ~(rdMask | cdMask);
     string largerQuery;
     vector<uint8_t> rewrittenResponse;
     bool ednsAdded = false;
@@ -214,8 +238,7 @@ void* tcpClientThread(int pipefd)
 
           if (!decrypted) {
             if (response.size() > 0) {
-              if (putNonBlockingMsgLen(ci.fd, response.size(), g_tcpSendTimeout))
-                writen2WithTimeout(ci.fd, (const char *) response.data(), response.size(), g_tcpSendTimeout);
+              sendResponseToClient(ci.fd, reinterpret_cast<char*>(response.data()), response.size(), response.size(), nullptr, nullptr);
             }
             break;
           }
@@ -323,9 +346,11 @@ void* tcpClientThread(int pipefd)
        }
 
        if(dq.dh->qr) { // something turned it into a response
-         if (putNonBlockingMsgLen(ci.fd, dq.len, g_tcpSendTimeout))
-           writen2WithTimeout(ci.fd, query, dq.len, g_tcpSendTimeout);
-
+          sendResponseToClient(ci.fd, queryBuffer, dq.len, dq.size
+#ifdef HAVE_DNSCRYPT
+                               , ci.cs->dnscryptCtx, dnsCryptQuery
+#endif
+            );
          g_stats.selfAnswered++;
          goto drop;
        }
@@ -359,8 +384,11 @@ void* tcpClientThread(int pipefd)
           uint16_t cachedResponseSize = sizeof cachedResponse;
           uint32_t allowExpired = ds ? 0 : g_staleCacheEntriesTTL;
           if (packetCache->get(dq, consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey, allowExpired)) {
-            if (putNonBlockingMsgLen(ci.fd, cachedResponseSize, g_tcpSendTimeout))
-              writen2WithTimeout(ci.fd, cachedResponse, cachedResponseSize, g_tcpSendTimeout);
+            sendResponseToClient(ci.fd, cachedResponse, cachedResponseSize, sizeof cachedResponse
+#ifdef HAVE_DNSCRYPT
+                                 , ci.cs->dnscryptCtx, dnsCryptQuery
+#endif
+              );
             g_stats.cacheHits++;
             goto drop;
           }
@@ -430,7 +458,7 @@ void* tcpClientThread(int pipefd)
           goto retry;
         }
 
-        uint16_t responseSize = rlen;
+        size_t responseSize = rlen;
 #ifdef HAVE_DNSCRYPT
         if (ci.cs->dnscryptCtx && (UINT16_MAX - DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) > rlen) {
           responseSize += DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
@@ -438,14 +466,6 @@ void* tcpClientThread(int pipefd)
 #endif
         char answerbuffer[responseSize];
         readn2WithTimeout(dsock, answerbuffer, rlen, ds->tcpRecvTimeout);
-        struct dnsheader* responseHeaders = (struct dnsheader*)answerbuffer;
-        uint16_t * responseFlags = getFlagsFromDNSHeader(responseHeaders);
-        /* clear the flags we are about to restore */
-        *responseFlags &= restoreFlagsMask;
-        /* only keep the flags we want to restore */
-        origFlags &= ~restoreFlagsMask;
-        /* set the saved flags as they were */
-        *responseFlags |= origFlags;
         char* response = answerbuffer;
         uint16_t responseLen = rlen;
         --ds->outstanding;
@@ -455,86 +475,29 @@ void* tcpClientThread(int pipefd)
           break;
         }
 
-        dh = (struct dnsheader*) response;
-        DNSName rqname;
-        uint16_t rqtype, rqclass;
-        try {
-          rqname=DNSName(response, responseLen, sizeof(dnsheader), false, &rqtype, &rqclass, &consumed);
-        }
-        catch(std::exception& e) {
-          if(rlen > (ssize_t)sizeof(dnsheader))
-            infolog("Backend %s sent us a response with id %d that did not parse: %s", ds->remote.toStringWithPort(), ntohs(dh->id), e.what());
-          g_stats.nonCompliantResponses++;
-          break;
-        }
-
-        if (rqtype != qtype || rqclass != qclass || rqname != qname) {
+        if (!responseContentMatches(response, responseLen, qname, qtype, qclass, ds->remote)) {
           break;
         }
 
-        if (ednsAdded) {
-          const char * optStart = NULL;
-          size_t optLen = 0;
-          bool last = false;
-
-          int res = locateEDNSOptRR(response, responseLen, &optStart, &optLen, &last);
-
-          if (res == 0) {
-            if (last) {
-              /* simply remove the last AR */
-              responseLen -= optLen;
-              uint16_t arcount = ntohs(responseHeaders->arcount);
-              arcount--;
-              responseHeaders->arcount = htons(arcount);
-            }
-            else {
-              /* Removing an intermediary RR could lead to compression error */
-              if (rewriteResponseWithoutEDNS(response, responseLen, rewrittenResponse) == 0) {
+        if (!fixUpResponse(&response, &responseLen, &responseSize, qname, origFlags, ednsAdded,
 #ifdef HAVE_DNSCRYPT
-                if (ci.cs->dnscryptCtx && rewrittenResponse.capacity() < responseSize && ci.cs->dnscryptCtx) {
-                  /* we preserve room for dnscrypt */
-                  rewrittenResponse.reserve(responseSize);
-                }
+                           dnsCryptQuery,
 #endif
-                responseSize = responseLen;
-                responseLen = rewrittenResponse.size();
-                response = reinterpret_cast<char*>(rewrittenResponse.data());
-              }
-              else {
-                warnlog("Error rewriting content");
-              }
-            }
-          }
+                           rewrittenResponse)) {
+          break;
         }
 
-       if(g_fixupCase) {
-         string realname = qname.toDNSString();
-         if (responseLen >= (sizeof(dnsheader) + realname.length())) {
-           memcpy(response + sizeof(dnsheader), realname.c_str(), realname.length());
-         }
-       }
-
        if (packetCache && !dq.skipCache) {
          packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen, true, dh->rcode == RCode::ServFail);
        }
 
+        if (!sendResponseToClient(ci.fd, response, responseLen, responseSize
 #ifdef HAVE_DNSCRYPT
-        if (ci.cs->dnscryptCtx) {
-          uint16_t encryptedResponseLen = 0;
-          int res = ci.cs->dnscryptCtx->encryptResponse(response, responseLen, responseSize, dnsCryptQuery, true, &encryptedResponseLen);
-
-          if (res == 0) {
-            responseLen = encryptedResponseLen;
-          } else {
-            /* dropping response */
-            vinfolog("Error encrypting the response, dropping.");
-            break;
-          }
-        }
+                                 , ci.cs->dnscryptCtx, dnsCryptQuery
 #endif
-
-        if (putNonBlockingMsgLen(ci.fd, responseLen, ds->tcpSendTimeout))
-          writen2WithTimeout(ci.fd, response, responseLen, ds->tcpSendTimeout);
+              )) {
+          break;
+        }
 
         g_stats.responses++;
         struct timespec answertime;
index 0fce07e9b37edc5852ebf6c4b7b6c9380305fe84..cb42bd76b90f6fc7d1f9e679d2639502f0e4334e 100644 (file)
@@ -159,6 +159,136 @@ static void doLatencyAverages(double udiff)
   doAvg(g_stats.latencyAvg1000000, udiff, 1000000);
 }
 
+bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote)
+{
+  uint16_t rqtype, rqclass;
+  unsigned int consumed;
+  DNSName rqname;
+  const struct dnsheader* dh = (struct dnsheader*) response;
+
+  if (responseLen < sizeof(dnsheader)) {
+    return false;
+  }
+
+  try {
+    rqname=DNSName(response, responseLen, sizeof(dnsheader), false, &rqtype, &rqclass, &consumed);
+  }
+  catch(std::exception& e) {
+    if(responseLen > (ssize_t)sizeof(dnsheader))
+      infolog("Backend %s sent us a response with id %d that did not parse: %s", remote.toStringWithPort(), ntohs(dh->id), e.what());
+    g_stats.nonCompliantResponses++;
+    return false;
+  }
+
+  if (rqtype != qtype || rqclass != qclass || rqname != qname) {
+    return false;
+  }
+
+  return true;
+}
+
+bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded,
+#ifdef HAVE_DNSCRYPT
+                   std::shared_ptr<DnsCryptQuery> dnsCryptQuery,
+#endif
+                   std::vector<uint8_t>& rewrittenResponse)
+{
+  static const uint16_t rdMask = 1 << FLAGS_RD_OFFSET;
+  static const uint16_t cdMask = 1 << FLAGS_CD_OFFSET;
+  static const uint16_t restoreFlagsMask = UINT16_MAX & ~(rdMask | cdMask);
+  struct dnsheader* dh = (struct dnsheader*) *response;
+
+  if (*responseLen < sizeof(dnsheader)) {
+    return false;
+  }
+
+  if(g_fixupCase) {
+    string realname = qname.toDNSString();
+    if (*responseLen >= (sizeof(dnsheader) + realname.length())) {
+      memcpy(*response + sizeof(dnsheader), realname.c_str(), realname.length());
+    }
+  }
+
+  uint16_t * flags = getFlagsFromDNSHeader(dh);
+  /* clear the flags we are about to restore */
+  *flags &= restoreFlagsMask;
+  /* only keep the flags we want to restore */
+  origFlags &= ~restoreFlagsMask;
+  /* set the saved flags as they were */
+  *flags |= origFlags;
+
+  if (ednsAdded) {
+    const char * optStart = NULL;
+    size_t optLen = 0;
+    bool last = false;
+
+    int res = locateEDNSOptRR(*response, *responseLen, &optStart, &optLen, &last);
+
+    if (res == 0) {
+      if (last) {
+        /* simply remove the last AR */
+        *responseLen -= optLen;
+        uint16_t arcount = ntohs(dh->arcount);
+        arcount--;
+        dh->arcount = htons(arcount);
+      }
+      else {
+        /* Removing an intermediary RR could lead to compression error */
+        if (rewriteResponseWithoutEDNS(*response, *responseLen, rewrittenResponse) == 0) {
+          *responseLen = rewrittenResponse.size();
+#ifdef HAVE_DNSCRYPT
+          if (dnsCryptQuery && (UINT16_MAX - DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) > *responseLen) {
+            rewrittenResponse.reserve(*responseLen + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE);
+          }
+          *responseSize = rewrittenResponse.capacity();
+#endif
+          *response = reinterpret_cast<char*>(rewrittenResponse.data());
+        }
+        else {
+          warnlog("Error rewriting content");
+        }
+      }
+    }
+  }
+
+  return true;
+}
+
+static bool sendUDPResponse(int origFD, char* response, uint16_t responseLen, size_t responseSize,
+#ifdef HAVE_DNSCRYPT
+                            std::shared_ptr<DnsCryptQuery> dnsCryptQuery,
+#endif
+                            int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
+{
+#ifdef HAVE_DNSCRYPT
+  uint16_t encryptedResponseLen = 0;
+  if(dnsCryptQuery) {
+    int res = dnsCryptQuery->ctx->encryptResponse(response, responseLen, responseSize, dnsCryptQuery, false, &encryptedResponseLen);
+
+    if (res == 0) {
+      responseLen = encryptedResponseLen;
+    } else {
+      /* dropping response */
+      vinfolog("Error encrypting the response, dropping.");
+      return false;
+    }
+  }
+#endif
+
+  if(delayMsec && g_delay) {
+    DelayedPacket dp{origFD, string(response,responseLen), origRemote, origDest};
+    g_delay->submit(dp, delayMsec);
+  }
+  else {
+    if(origDest.sin4.sin_family == 0)
+      sendto(origFD, response, responseLen, 0, (struct sockaddr*)&origRemote, origRemote.getSocklen());
+    else
+      sendfromto(origFD, response, responseLen, 0, origDest, origRemote);
+  }
+
+  return true;
+}
+
 // listens on a dedicated socket, lobs answers from downstream servers to original requestors
 void* responderThread(std::shared_ptr<DownstreamState> state)
 {
@@ -168,24 +298,18 @@ void* responderThread(std::shared_ptr<DownstreamState> state)
   char packet[4096];
 #endif
   static_assert(sizeof(packet) <= UINT16_MAX, "Packet size should fit in a uint16_t");
-  const uint16_t rdMask = 1 << FLAGS_RD_OFFSET;
-  const uint16_t cdMask = 1 << FLAGS_CD_OFFSET;
-  const uint16_t restoreFlagsMask = UINT16_MAX & ~(rdMask | cdMask);
   vector<uint8_t> rewrittenResponse;
-  uint16_t qtype, qclass;
 
   struct dnsheader* dh = (struct dnsheader*)packet;
   for(;;) {
     ssize_t got = recv(state->fd, packet, sizeof(packet), 0);
     char * response = packet;
-#ifdef HAVE_DNSCRYPT
-    uint16_t responseSize = sizeof(packet);
-#endif
+    size_t responseSize = sizeof(packet);
 
     if (got < (ssize_t) sizeof(dnsheader))
       continue;
 
-    size_t responseLen = (size_t) got;
+    uint16_t responseLen = (size_t) got;
 
     if(dh->id >= state->idStates.size())
       continue;
@@ -202,108 +326,38 @@ void* responderThread(std::shared_ptr<DownstreamState> state)
        mostly mess up the outstanding counter.
     */
     ids->age = 0;
-    unsigned int consumed;
-    DNSName qname;
-    try {
-      qname=DNSName(packet, responseLen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
-    }
-    catch(std::exception& e) {
-      if(got > (ssize_t)sizeof(dnsheader))
-        infolog("Backend %s sent us a response with id %d that did not parse: %s", state->remote.toStringWithPort(), ntohs(dh->id), e.what());
-      g_stats.nonCompliantResponses++;
+
+    if (!responseContentMatches(response, responseLen, ids->qname, ids->qtype, ids->qclass, state->remote)) {
       continue;
     }
-    if (qtype != ids->qtype || qclass != ids->qclass || qname != ids->qname)
-      continue;
 
     --state->outstanding;  // you'd think an attacker could game this, but we're using connected socket
 
-    if(g_fixupCase) {
-      string realname = ids->qname.toDNSString();
-      if (responseLen >= (sizeof(dnsheader) + realname.length())) {
-        memcpy(packet+12, realname.c_str(), realname.length());
-      }
-    }
-
     if(dh->tc && g_truncateTC) {
-      truncateTC(packet, (uint16_t*) &responseLen);
+      truncateTC(response, (uint16_t*) &responseLen);
     }
-    uint16_t * flags = getFlagsFromDNSHeader(dh);
-    uint16_t origFlags = ids->origFlags;
-    /* clear the flags we are about to restore */
-    *flags &= restoreFlagsMask;
-    /* only keep the flags we want to restore */
-    origFlags &= ~restoreFlagsMask;
-    /* set the saved flags as they were */
-    *flags |= origFlags;
 
     dh->id = ids->origID;
 
-    if (ids->ednsAdded) {
-      const char * optStart = NULL;
-      size_t optLen = 0;
-      bool last = false;
-
-      int res = locateEDNSOptRR(response, responseLen, &optStart, &optLen, &last);
-
-      if (res == 0) {
-        if (last) {
-          /* simply remove the last AR */
-          responseLen -= optLen;
-          uint16_t arcount = ntohs(dh->arcount);
-          arcount--;
-          dh->arcount = htons(arcount);
-        }
-        else {
-          /* Removing an intermediary RR could lead to compression error */
-          if (rewriteResponseWithoutEDNS(response, responseLen, rewrittenResponse) == 0) {
-            responseLen = rewrittenResponse.size();
+    if (!fixUpResponse(&response, &responseLen, &responseSize, ids->qname, ids->origFlags, ids->ednsAdded,
 #ifdef HAVE_DNSCRYPT
-            if (ids->dnsCryptQuery && (UINT16_MAX - DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) > responseLen) {
-              rewrittenResponse.reserve(responseLen + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE);
-            }
-            responseSize = rewrittenResponse.capacity();
+                       ids->dnsCryptQuery,
 #endif
-            response = reinterpret_cast<char*>(rewrittenResponse.data());
-          }
-          else {
-            warnlog("Error rewriting content");
-          }
-        }
-      }
+                       rewrittenResponse)) {
+      continue;
     }
 
-    g_stats.responses++;
-
     if (ids->packetCache && !ids->skipCache) {
-      ids->packetCache->insert(ids->cacheKey, qname, qtype, qclass, response, responseLen, false, dh->rcode == RCode::ServFail);
+      ids->packetCache->insert(ids->cacheKey, ids->qname, ids->qtype, ids->qclass, response, responseLen, false, dh->rcode == RCode::ServFail);
     }
 
+    sendUDPResponse(origFD, response, responseLen, responseSize,
 #ifdef HAVE_DNSCRYPT
-    uint16_t encryptedResponseLen = 0;
-    if(ids->dnsCryptQuery) {
-      int res = ids->dnsCryptQuery->ctx->encryptResponse(response, responseLen, responseSize, ids->dnsCryptQuery, false, &encryptedResponseLen);
-
-      if (res == 0) {
-        responseLen = encryptedResponseLen;
-      } else {
-        /* dropping response */
-        vinfolog("Error encrypting the response, dropping.");
-        continue;
-      }
-    }
+                    ids->dnsCryptQuery,
 #endif
+                    ids->delayMsec, ids->origDest, ids->origRemote);
 
-    if(ids->delayMsec && g_delay) {
-      DelayedPacket dp{origFD, string(response,responseLen), ids->origRemote, ids->origDest};
-      g_delay->submit(dp, ids->delayMsec);
-    }
-    else {
-      if(ids->origDest.sin4.sin_family == 0)
-       sendto(origFD, response, responseLen, 0, (struct sockaddr*)&ids->origRemote, ids->origRemote.getSocklen());
-      else
-       sendfromto(origFD, response, responseLen, 0, ids->origDest, ids->origRemote);
-    }
+    g_stats.responses++;
 
     double udiff = ids->sentTime.udiff();
     vinfolog("Got answer from %s, relayed to %s, took %f usec", state->remote.toStringWithPort(), ids->origRemote.toStringWithPort(), udiff);
@@ -312,7 +366,7 @@ void* responderThread(std::shared_ptr<DownstreamState> state)
       struct timespec ts;
       clock_gettime(CLOCK_MONOTONIC, &ts);
       std::lock_guard<std::mutex> lock(g_rings.respMutex);
-      g_rings.respRing.push_back({ts, ids->origRemote, qname, qtype, (unsigned int)udiff, (unsigned int)got, *dh, state->remote});
+      g_rings.respRing.push_back({ts, ids->origRemote, ids->qname, ids->qtype, (unsigned int)udiff, (unsigned int)got, *dh, state->remote});
     }
     if(dh->rcode == RCode::ServFail)
       g_stats.servfailResponses++;
@@ -682,10 +736,14 @@ try
         if (!decrypted) {
           if (response.size() > 0) {
             ComboAddress dest;
-            if(HarvestDestinationAddress(&msgh, &dest))
-              sendfromto(cs->udpFD, (const char *) response.data(), response.size(), 0, dest, remote);
-            else
-              sendto(cs->udpFD, response.data(), response.size(), 0, (struct sockaddr*)&remote, remote.getSocklen());
+            if(!HarvestDestinationAddress(&msgh, &dest)) {
+              dest.sin4.sin_family = 0;
+            }
+            sendUDPResponse(cs->udpFD, reinterpret_cast<char*>(response.data()), response.size(), response.size(),
+#ifdef HAVE_DNSCRYPT
+                            nullptr,
+#endif
+                            0, dest, remote);
           }
           continue;
         }
@@ -797,31 +855,18 @@ try
       if(dq.dh->qr) { // something turned it into a response
         char* response = query;
         uint16_t responseLen = dq.len;
-#ifdef HAVE_DNSCRYPT
         uint16_t responseSize = dq.size;
-#endif
         g_stats.selfAnswered++;
 
-#ifdef HAVE_DNSCRYPT
-        uint16_t encryptedResponseLen = 0;
-
-        if(dnsCryptQuery) {
-          int res = cs->dnscryptCtx->encryptResponse(response, responseLen, responseSize, dnsCryptQuery, false, &encryptedResponseLen);
-
-          if (res == 0) {
-            responseLen = encryptedResponseLen;
-          } else {
-            /* dropping response */
-            continue;
-          }
+        ComboAddress dest;
+        if(!HarvestDestinationAddress(&msgh, &dest)) {
+          dest.sin4.sin_family = 0;
         }
+        sendUDPResponse(cs->udpFD, response, responseLen, responseSize,
+#ifdef HAVE_DNSCRYPT
+                        dnsCryptQuery,
 #endif
-        ComboAddress dest;
-        if(HarvestDestinationAddress(&msgh, &dest))
-          sendfromto(cs->udpFD, response, responseLen, 0, dest, remote);
-        else
-          sendto(cs->udpFD, response, responseLen, 0, (struct sockaddr*)&remote, remote.getSocklen());
-
+                        0, dest, remote);
         continue;
       }
 
@@ -847,10 +892,14 @@ try
         uint32_t allowExpired = ss ? 0 : g_staleCacheEntriesTTL;
         if (packetCache->get(dq, consumed, dh->id, cachedResponse, &cachedResponseSize, &cacheKey, allowExpired)) {
           ComboAddress dest;
-          if(HarvestDestinationAddress(&msgh, &dest))
-            sendfromto(cs->udpFD, cachedResponse, cachedResponseSize, 0, dest, remote);
-          else
-            sendto(cs->udpFD, cachedResponse, cachedResponseSize, 0, (struct sockaddr*)&remote, remote.getSocklen());
+          if(!HarvestDestinationAddress(&msgh, &dest)) {
+            dest.sin4.sin_family = 0;
+          }
+          sendUDPResponse(cs->udpFD, cachedResponse, cachedResponseSize, sizeof cachedResponse,
+#ifdef HAVE_DNSCRYPT
+                          dnsCryptQuery,
+#endif
+                          0, dest, remote);
           g_stats.cacheHits++;
           g_stats.latency0_1++;  // we're not going to measure this
           doLatencyAverages(0);  // same
index bddbfb229558187ba4b2e4f44d75156a68e16973..946ceed3bdc8f25c0110700d031d21de8274e781 100644 (file)
@@ -506,6 +506,13 @@ void setLuaSideEffect();   // set to report a side effect, cancelling all _no_ s
 bool getLuaNoSideEffect(); // set if there were only explicit declarations of _no_ side effect
 void resetLuaSideEffect(); // reset to indeterminate state
 
+bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote);
+bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded,
+#ifdef HAVE_DNSCRYPT
+                   std::shared_ptr<DnsCryptQuery> dnsCryptQuery,
+#endif
+                   std::vector<uint8_t>& rewrittenResponse);
+
 #ifdef HAVE_DNSCRYPT
 extern std::vector<std::tuple<ComboAddress,DnsCryptContext,bool>> g_dnsCryptLocals;
 
index f46bd6df785e0a69662840477fbf01b03039ef26..5426a5fc41c09d5650689f595ad971f6618e313e 100644 (file)
@@ -70,15 +70,35 @@ class DNSCryptClient(object):
         self._resolverPort = resolverPort
         self._resolverCertificates = []
         self._publicKey, self._privateKey = libnacl.crypto_box_keypair()
+        self._timeout = timeout
 
         addrType = self._addrToSocketType(self._resolverAddress)
         self._sock = socket.socket(addrType, socket.SOCK_DGRAM)
         self._sock.settimeout(timeout)
         self._sock.connect((self._resolverAddress, self._resolverPort))
 
-    def _sendQuery(self, queryContent):
-        self._sock.send(queryContent)
-        data = self._sock.recv(4096)
+    def _sendQuery(self, queryContent, tcp=False):
+        if tcp:
+            addrType = self._addrToSocketType(self._resolverAddress)
+            sock = socket.socket(addrType, socket.SOCK_STREAM)
+            sock.settimeout(self._timeout)
+            sock.connect((self._resolverAddress, self._resolverPort))
+            sock.send(struct.pack("!H", len(queryContent)))
+        else:
+            sock = self._sock
+
+        sock.send(queryContent)
+
+        data = None
+        if tcp:
+            got = sock.recv(2)
+            print(len(got))
+            if got:
+                (rlen,) = struct.unpack("!H", got)
+                data = sock.recv(rlen)
+        else:
+            data = sock.recv(4096)
+
         return data
 
     def _hasValidResolverCertificate(self):
@@ -124,12 +144,12 @@ class DNSCryptClient(object):
         nonce = libnacl.utils.rand_nonce()
         return nonce[:(DNSCryptClient.DNSCRYPT_NONCE_SIZE / 2)]
 
-    def _encryptQuery(self, queryContent, resolverCert, nonce):
+    def _encryptQuery(self, queryContent, resolverCert, nonce, tcp=False):
         header = resolverCert.clientMagic + self._publicKey + nonce
         requiredSize = len(header) + self.DNSCRYPT_MAC_SIZE + len(queryContent)
         paddingSize = self.DNSCRYPT_PADDED_BLOCK_SIZE - (len(queryContent) % self.DNSCRYPT_PADDED_BLOCK_SIZE)
         # padding size should be DNSCRYPT_PADDED_BLOCK_SIZE <= padding size <= 4096
-        if requiredSize < self.DNSCRYPT_MIN_UDP_LENGTH:
+        if not tcp and requiredSize < self.DNSCRYPT_MIN_UDP_LENGTH:
             paddingSize += self.DNSCRYPT_MIN_UDP_LENGTH - requiredSize
             requiredSize = self.DNSCRYPT_MIN_UDP_LENGTH
 
@@ -168,7 +188,7 @@ class DNSCryptClient(object):
 
         return cleartext[:idx+1]
 
-    def query(self, queryContent):
+    def query(self, queryContent, tcp=False):
 
         if not self._hasValidResolverCertificate():
             self._getResolverCertificates()
@@ -177,7 +197,7 @@ class DNSCryptClient(object):
         resolverCert = self._getResolverCertificate()
         if resolverCert is None:
             raise Exception("No valid certificate found")
-        encryptedQuery = self._encryptQuery(queryContent, resolverCert, nonce)
-        encryptedResponse = self._sendQuery(encryptedQuery)
+        encryptedQuery = self._encryptQuery(queryContent, resolverCert, nonce, tcp)
+        encryptedResponse = self._sendQuery(encryptedQuery, tcp)
         response = self._decryptResponse(encryptedResponse, resolverCert, nonce)
         return response
index c8dba56de130c6d5ba85b403d17edcc92ec08efc..abf1da5ca4fd0bc95a18e1c30a997a5fa74b83e8 100644 (file)
@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 import time
-import unittest
 import dns
 import dns.message
 from dnsdisttests import DNSDistTest
@@ -43,7 +42,7 @@ class TestDNSCrypt(DNSDistTest):
                                     3600,
                                     dns.rdataclass.IN,
                                     dns.rdatatype.A,
-                                    '127.0.0.1')
+                                    '192.2.0.1')
         response.answer.append(rrset)
 
         self._toResponderQueue.put(response)
@@ -59,6 +58,19 @@ class TestDNSCrypt(DNSDistTest):
         self.assertEquals(query, receivedQuery)
         self.assertEquals(response, receivedResponse)
 
+        self._toResponderQueue.put(response)
+        data = client.query(query.to_wire(), tcp=True)
+        receivedResponse = dns.message.from_wire(data)
+        receivedQuery = None
+        if not self._fromResponderQueue.empty():
+            receivedQuery = self._fromResponderQueue.get(query)
+
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
     def testResponseLargerThanPaddedQuery(self):
         """
         DNSCrypt: response larger than query
@@ -95,6 +107,59 @@ class TestDNSCrypt(DNSDistTest):
         self.assertTrue(len(receivedResponse.authority) == 0)
         self.assertTrue(len(receivedResponse.additional) == 0)
 
-if __name__ == '__main__':
-    unittest.main()
-    exit(0)
+class TestDNSCryptWithCache(DNSDistTest):
+    _dnsDistPortDNSCrypt = 8443
+    _providerFingerprint = 'E1D7:2108:9A59:BF8D:F101:16FA:ED5E:EA6A:9F6C:C78F:7F91:AF6B:027E:62F4:69C3:B1AA'
+    _providerName = "2.provider.name"
+    _resolverCertificateSerial = 42
+    # valid from 60s ago until 2h from now
+    _resolverCertificateValidFrom = time.time() - 60
+    _resolverCertificateValidUntil = time.time() + 7200
+    _config_params = ['_resolverCertificateSerial', '_resolverCertificateValidFrom', '_resolverCertificateValidUntil', '_dnsDistPortDNSCrypt', '_providerName', '_testServerPort']
+    _config_template = """
+    generateDNSCryptCertificate("DNSCryptProviderPrivate.key", "DNSCryptResolver.cert", "DNSCryptResolver.key", %d, %d, %d)
+    addDNSCryptBind("127.0.0.1:%d", "%s", "DNSCryptResolver.cert", "DNSCryptResolver.key")
+    pc = newPacketCache(5, 86400, 1)
+    getPool(""):setCache(pc)
+    newServer{address="127.0.0.1:%s"}
+    """
+
+    def testCachedSimpleA(self):
+        """
+        DNSCrypt: encrypted A query served from cache
+        """
+        client = dnscrypt.DNSCryptClient(self._providerName, self._providerFingerprint, "127.0.0.1", 8443)
+        name = 'cacheda.dnscrypt.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.2.0.1')
+        response.answer.append(rrset)
+
+        # first query to fill the cache
+        self._toResponderQueue.put(response)
+        data = client.query(query.to_wire())
+        receivedResponse = dns.message.from_wire(data)
+        receivedQuery = None
+        if not self._fromResponderQueue.empty():
+            receivedQuery = self._fromResponderQueue.get(query)
+
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        # second query should get a cached response
+        data = client.query(query.to_wire())
+        receivedResponse = dns.message.from_wire(data)
+        receivedQuery = None
+        if not self._fromResponderQueue.empty():
+            receivedQuery = self._fromResponderQueue.get(query)
+
+        self.assertEquals(receivedQuery, None)
+        self.assertTrue(receivedResponse)
+        self.assertEquals(response, receivedResponse)