]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Refactoring to merge the UDP and TCP paths
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 25 Feb 2019 14:54:13 +0000 (15:54 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 4 Apr 2019 09:42:50 +0000 (11:42 +0200)
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/test-dnsdistrules_cc.cc

index 41465e16da06f86033d4e000ae73afd9e2b8140e..f13422943e0382e8e08b92cb5618597808a2f9b7 100644 (file)
@@ -296,16 +296,16 @@ void tcpClientThread(int pipefd)
     uint16_t qlen, rlen;
     vector<uint8_t> rewrittenResponse;
     shared_ptr<DownstreamState> ds;
-    ComboAddress dest;
-    dest.reset();
-    dest.sin4.sin_family = ci.remote.sin4.sin_family;
-    socklen_t len = dest.getSocklen();
     size_t queriesCount = 0;
     time_t connectionStartTime = time(NULL);
     std::vector<char> queryBuffer;
     std::vector<char> answerBuffer;
 
-    if (getsockname(ci.fd, (sockaddr*)&dest, &len)) {
+    ComboAddress dest;
+    dest.reset();
+    dest.sin4.sin_family = ci.remote.sin4.sin_family;
+    socklen_t socklen = dest.getSocklen();
+    if (getsockname(ci.fd, (sockaddr*)&dest, &socklen)) {
       dest = ci.cs->local;
     }
 
@@ -341,8 +341,6 @@ void tcpClientThread(int pipefd)
           break;
         }
 
-        bool ednsAdded = false;
-        bool ecsAdded = false;
         /* allocate a bit more memory to be able to spoof the content,
            or to add ECS without allocating a new buffer */
         queryBuffer.resize(qlen + 512);
@@ -350,219 +348,51 @@ void tcpClientThread(int pipefd)
         char* query = &queryBuffer[0];
         handler.read(query, qlen, g_tcpRecvTimeout, remainingTime);
 
-        /* we need this one to be accurate ("real") for the protobuf message */
-       struct timespec queryRealTime;
-       struct timespec now;
-       gettime(&now);
-       gettime(&queryRealTime, true);
+        /* we need an accurate ("real") value for the response and
+           to store into the IDS, but not for insertion into the
+           rings for example */
+        struct timespec now;
+        struct timespec queryRealTime;
+        gettime(&now);
+        gettime(&queryRealTime, true);
 
-#ifdef HAVE_DNSCRYPT
         std::shared_ptr<DNSCryptQuery> dnsCryptQuery = nullptr;
 
-        if (ci.cs->dnscryptCtx) {
-          dnsCryptQuery = std::make_shared<DNSCryptQuery>(ci.cs->dnscryptCtx);
-          uint16_t decryptedQueryLen = 0;
-          vector<uint8_t> response;
-          bool decrypted = handleDNSCryptQuery(query, qlen, dnsCryptQuery, &decryptedQueryLen, true, queryRealTime.tv_sec, response);
-
-          if (!decrypted) {
-            if (response.size() > 0) {
-              handler.writeSizeAndMsg(response.data(), response.size(), g_tcpSendTimeout);
-            }
-            break;
-          }
-          qlen = decryptedQueryLen;
+#ifdef HAVE_DNSCRYPT
+        auto dnsCryptResponse = checkDNSCryptQuery(*ci.cs, query, qlen, dnsCryptQuery, queryRealTime.tv_sec, true);
+        if (dnsCryptResponse) {
+          handler.writeSizeAndMsg(reinterpret_cast<char*>(dnsCryptResponse->data()), static_cast<uint16_t>(dnsCryptResponse->size()), g_tcpSendTimeout);
+          continue;
         }
 #endif
-        struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query);
 
+        struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query);
         if (!checkQueryHeaders(dh)) {
-          goto drop;
+          break;
         }
 
-       string poolname;
-       int delayMsec=0;
-
-       const uint16_t* flags = getFlagsFromDNSHeader(dh);
-       uint16_t origFlags = *flags;
-       uint16_t qtype, qclass;
-       unsigned int consumed = 0;
-       DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
-       DNSQuestion dq(&qname, qtype, qclass, consumed, &dest, &ci.remote, dh, queryBuffer.size(), qlen, true, &queryRealTime);
+        uint16_t qtype, qclass;
+        unsigned int consumed = 0;
+        DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
+        DNSQuestion dq(&qname, qtype, qclass, consumed, &dest, &ci.remote, dh, queryBuffer.size(), qlen, true, &queryRealTime);
+        dq.dnsCryptQuery = std::move(dnsCryptQuery);
 
-       if (!processQuery(holders, dq, poolname, &delayMsec, now)) {
-         goto drop;
-       }
-
-       if(dq.dh->qr) { // something turned it into a response
-          fixUpQueryTurnedResponse(dq, origFlags);
-
-          DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(query), dq.size, dq.len, true, &queryRealTime);
-#ifdef HAVE_PROTOBUF
-          dr.uniqueId = dq.uniqueId;
-#endif
-          dr.qTag = dq.qTag;
+        responseSender sender = [&handler](const ClientState& cs, const char* data, uint16_t dataSize, int delayMsec, const ComboAddress& dest, const ComboAddress& remote) {
+          handler.writeSizeAndMsg(data, dataSize, g_tcpSendTimeout);
+        };
 
-          if (!processResponse(holders.selfAnsweredRespRulactions, dr, &delayMsec)) {
-            goto drop;
-          }
-
-#ifdef HAVE_DNSCRYPT
-          if (!encryptResponse(query, &dq.len, dq.size, true, dnsCryptQuery, nullptr, nullptr)) {
-            goto drop;
+        bool dropped = false;
+        auto ds = processQuery(dq, *ci.cs, holders, sender, dropped);
+        if (!ds) {
+          if (dropped) {
+            break;
           }
-#endif
-          handler.writeSizeAndMsg(query, dq.len, g_tcpSendTimeout);
-          ++g_stats.selfAnswered;
           continue;
         }
 
-        std::shared_ptr<ServerPool> serverPool = getPool(*holders.pools, poolname);
-        std::shared_ptr<DNSDistPacketCache> packetCache = serverPool->packetCache;
-
-        auto policy = *(holders.policy);
-        if (serverPool->policy != nullptr) {
-          policy = *(serverPool->policy);
-        }
-        auto servers = serverPool->getServers();
-        if (policy.isLua) {
-          std::lock_guard<std::mutex> lock(g_luamutex);
-          ds = policy.policy(servers, &dq);
-        }
-        else {
-          ds = policy.policy(servers, &dq);
-        }
-
-        uint32_t cacheKeyNoECS = 0;
-        uint32_t cacheKey = 0;
-        boost::optional<Netmask> subnet;
+        // check how that would work!!
         char cachedResponse[4096];
         uint16_t cachedResponseSize = sizeof cachedResponse;
-        uint32_t allowExpired = ds ? 0 : g_staleCacheEntriesTTL;
-        bool useZeroScope = false;
-
-        bool dnssecOK = false;
-        if (packetCache && !dq.skipCache) {
-          dnssecOK = (getEDNSZ(dq) & EDNS_HEADER_FLAG_DO);
-        }
-
-        if (dq.useECS && ((ds && ds->useECS) || (!ds && serverPool->getECS()))) {
-          // we special case our cache in case a downstream explicitly gave us a universally valid response with a 0 scope
-          if (packetCache && !dq.skipCache && (!ds || !ds->disableZeroScope) && packetCache->isECSParsingEnabled()) {
-            if (packetCache->get(dq, consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKeyNoECS, subnet, dnssecOK, allowExpired)) {
-              DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, (dnsheader*) cachedResponse, sizeof cachedResponse, cachedResponseSize, true, &queryRealTime);
-#ifdef HAVE_PROTOBUF
-              dr.uniqueId = dq.uniqueId;
-#endif
-              dr.qTag = dq.qTag;
-
-              if (!processResponse(holders.cacheHitRespRulactions, dr, &delayMsec)) {
-                goto drop;
-              }
-
-#ifdef HAVE_DNSCRYPT
-              if (!encryptResponse(cachedResponse, &cachedResponseSize, sizeof cachedResponse, true, dnsCryptQuery, nullptr, nullptr)) {
-                goto drop;
-              }
-#endif
-              handler.writeSizeAndMsg(cachedResponse, cachedResponseSize, g_tcpSendTimeout);
-              g_stats.cacheHits++;
-              switch (dr.dh->rcode) {
-              case RCode::NXDomain:
-                ++g_stats.frontendNXDomain;
-                break;
-              case RCode::ServFail:
-                ++g_stats.frontendServFail;
-                break;
-              case RCode::NoError:
-                ++g_stats.frontendNoError;
-                break;
-              }
-              continue;
-            }
-
-            if (!subnet) {
-              /* there was no existing ECS on the query, enable the zero-scope feature */
-              useZeroScope = true;
-            }
-          }
-
-          if (!handleEDNSClientSubnet(dq, &(ednsAdded), &(ecsAdded), g_preserveTrailingData)) {
-            vinfolog("Dropping query from %s because we couldn't insert the ECS value", ci.remote.toStringWithPort());
-            goto drop;
-          }
-        }
-
-        if (packetCache && !dq.skipCache) {
-          if (packetCache->get(dq, (uint16_t) consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey, subnet, dnssecOK, allowExpired)) {
-            DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, (dnsheader*) cachedResponse, sizeof cachedResponse, cachedResponseSize, true, &queryRealTime);
-#ifdef HAVE_PROTOBUF
-            dr.uniqueId = dq.uniqueId;
-#endif
-            dr.qTag = dq.qTag;
-
-            if (!processResponse(holders.cacheHitRespRulactions, dr, &delayMsec)) {
-              goto drop;
-            }
-
-#ifdef HAVE_DNSCRYPT
-            if (!encryptResponse(cachedResponse, &cachedResponseSize, sizeof cachedResponse, true, dnsCryptQuery, nullptr, nullptr)) {
-              goto drop;
-            }
-#endif
-            handler.writeSizeAndMsg(cachedResponse, cachedResponseSize, g_tcpSendTimeout);
-            ++g_stats.cacheHits;
-            switch (dr.dh->rcode) {
-            case RCode::NXDomain:
-              ++g_stats.frontendNXDomain;
-              break;
-            case RCode::ServFail:
-              ++g_stats.frontendServFail;
-              break;
-            case RCode::NoError:
-              ++g_stats.frontendNoError;
-              break;
-            }
-            continue;
-          }
-          ++g_stats.cacheMisses;
-        }
-
-        if(!ds) {
-          ++g_stats.noPolicy;
-
-          if (g_servFailOnNoPolicy) {
-            restoreFlags(dh, origFlags);
-            dq.dh->rcode = RCode::ServFail;
-            dq.dh->qr = true;
-
-            DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(query), dq.size, dq.len, false, &queryRealTime);
-#ifdef HAVE_PROTOBUF
-            dr.uniqueId = dq.uniqueId;
-#endif
-            dr.qTag = dq.qTag;
-
-            if (!processResponse(holders.selfAnsweredRespRulactions, dr, &delayMsec)) {
-              goto drop;
-            }
-
-#ifdef HAVE_DNSCRYPT
-            if (!encryptResponse(query, &dq.len, dq.size, true, dnsCryptQuery, nullptr, nullptr)) {
-              goto drop;
-            }
-#endif
-            handler.writeSizeAndMsg(query, dq.len, g_tcpSendTimeout);
-
-            // no response-only statistics counter to update.
-            continue;
-          }
-
-          break;
-        }
-
-        if (dq.addXPF && ds->xpfRRCode != 0) {
-          addXPF(dq, ds->xpfRRCode, g_preserveTrailingData);
-        }
 
        int dsock = -1;
        uint16_t downstreamFailures=0;
@@ -580,7 +410,6 @@ void tcpClientThread(int pipefd)
 #endif /* MSG_FASTOPEN */
         }
 
-        ds->queries++;
         ds->outstanding++;
         outstanding = true;
 
@@ -641,7 +470,7 @@ void tcpClientThread(int pipefd)
           freshConn=true;
 #endif /* MSG_FASTOPEN */
           if(xfrStarted) {
-            goto drop;
+            break;
           }
           goto retry;
         }
@@ -649,7 +478,7 @@ void tcpClientThread(int pipefd)
         size_t responseSize = rlen;
         uint16_t addRoom = 0;
 #ifdef HAVE_DNSCRYPT
-        if (dnsCryptQuery && (UINT16_MAX - rlen) > (uint16_t) DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) {
+        if (dq.dnsCryptQuery && (UINT16_MAX - rlen) > (uint16_t) DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) {
           addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
         }
 #endif
@@ -674,7 +503,7 @@ void tcpClientThread(int pipefd)
         }
         firstPacket=false;
         bool zeroScope = false;
-        if (!fixUpResponse(&response, &responseLen, &responseSize, qname, origFlags, ednsAdded, ecsAdded, rewrittenResponse, addRoom, useZeroScope ? &zeroScope : nullptr)) {
+        if (!fixUpResponse(&response, &responseLen, &responseSize, qname, dq.origFlags, dq.ednsAdded, dq.ecsAdded, rewrittenResponse, addRoom, dq.useZeroScope ? &zeroScope : nullptr)) {
           break;
         }
 
@@ -685,12 +514,12 @@ void tcpClientThread(int pipefd)
 #endif
         dr.qTag = dq.qTag;
 
-        if (!processResponse(localRespRulactions, dr, &delayMsec)) {
+        if (!processResponse(localRespRulactions, dr, &dq.delayMsec)) {
           break;
         }
 
-       if (packetCache && !dq.skipCache) {
-          if (!useZeroScope) {
+       if (dq.packetCache && !dq.skipCache) {
+          if (!dq.useZeroScope) {
             /* if the query was not suitable for zero-scope, for
                example because it had an existing ECS entry so the hash is
                not really 'no ECS', so just insert it for the existing subnet
@@ -702,12 +531,12 @@ void tcpClientThread(int pipefd)
             zeroScope = false;
           }
           // if zeroScope, pass the pre-ECS hash-key and do not pass the subnet to the cache
-          packetCache->insert(zeroScope ? cacheKeyNoECS : cacheKey, zeroScope ? boost::none : subnet, origFlags, dnssecOK, qname, qtype, qclass, response, responseLen, true, dh->rcode, dq.tempFailureTTL);
+          dq.packetCache->insert(zeroScope ? dq.cacheKeyNoECS : dq.cacheKey, zeroScope ? boost::none : dq.subnet, dq.origFlags, dq.dnssecOK, qname, qtype, qclass, response, responseLen, true, dh->rcode, dq.tempFailureTTL);
        }
 
 #ifdef HAVE_DNSCRYPT
-        if (!encryptResponse(response, &responseLen, responseSize, true, dnsCryptQuery, &dh, &dhCopy)) {
-          goto drop;
+        if (!encryptResponse(response, &responseLen, responseSize, true, dq.dnsCryptQuery, &dh, &dhCopy)) {
+          break;
         }
 #endif
         if (!handler.writeSizeAndMsg(response, responseLen, g_tcpSendTimeout)) {
@@ -752,9 +581,11 @@ void tcpClientThread(int pipefd)
         rewrittenResponse.clear();
       }
     }
-    catch(...) {}
-
-  drop:;
+    catch(const std::exception& e) {
+      vinfolog("Got exception while handling TCP query: %s", e.what());
+    }
+    catch(...) {
+    }
 
     vinfolog("Closing TCP client connection with %s", ci.remote.toStringWithPort());
 
index de8c021981165ccd98784b1ae851e8db3be15c9a..5d6d8850d31f6dd5cba05c8777279b65883bcd15 100644 (file)
@@ -249,7 +249,7 @@ bool responseContentMatches(const char* response, const uint16_t responseLen, co
   return true;
 }
 
-void restoreFlags(struct dnsheader* dh, uint16_t origFlags)
+static void restoreFlags(struct dnsheader* dh, uint16_t origFlags)
 {
   static const uint16_t rdMask = 1 << FLAGS_RD_OFFSET;
   static const uint16_t cdMask = 1 << FLAGS_CD_OFFSET;
@@ -263,7 +263,7 @@ void restoreFlags(struct dnsheader* dh, uint16_t origFlags)
   *flags |= origFlags;
 }
 
-bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags)
+static bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags)
 {
   restoreFlags(dq.dh, origFlags);
 
@@ -392,7 +392,7 @@ bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize,
 }
 #endif
 
-static bool sendUDPResponse(int origFD, char* response, uint16_t responseLen, int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
+static bool sendUDPResponse(int origFD, const char* response, uint16_t responseLen, int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
 {
   if(delayMsec && g_delay) {
     DelayedPacket dp{origFD, string(response,responseLen), origRemote, origDest};
@@ -416,7 +416,7 @@ static bool sendUDPResponse(int origFD, char* response, uint16_t responseLen, in
 }
 
 
-static int pickBackendSocketForSending(DownstreamState* state)
+static int pickBackendSocketForSending(std::shared_ptr<DownstreamState>& state)
 {
   return state->sockets[state->socketsOffset++ % state->sockets.size()];
 }
@@ -971,9 +971,9 @@ static void spoofResponseFromString(DNSQuestion& dq, const string& spoofContent)
   }
 }
 
-bool processQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, int* delayMsec, const struct timespec& now)
+static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, const struct timespec& now)
 {
-  g_rings.insertQuery(now,*dq.remote,*dq.qname,dq.qtype,dq.len,*dq.dh);
+  g_rings.insertQuery(now, *dq.remote, *dq.qname, dq.qtype, dq.len, *dq.dh);
 
   if(g_qcount.enabled) {
     string qname = (*dq.qname).toString(".");
@@ -1155,7 +1155,7 @@ bool processQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, int*
         break;
         /* non-terminal actions follow */
       case DNSAction::Action::Delay:
-        *delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
+        dq.delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
         break;
       case DNSAction::Action::None:
         /* fall-through */
@@ -1207,7 +1207,7 @@ bool processResponse(LocalStateHolder<vector<DNSDistResponseRuleAction> >& local
   return true;
 }
 
-static ssize_t udpClientSendRequestToBackend(DownstreamState* ss, const int sd, const char* request, const size_t requestLen, bool healthCheck=false)
+static ssize_t udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState>& ss, const int sd, const char* request, const size_t requestLen, bool healthCheck=false)
 {
   ssize_t result;
 
@@ -1271,7 +1271,7 @@ static bool isUDPQueryAcceptable(ClientState& cs, LocalHolders& holders, const s
 }
 
 #ifdef HAVE_DNSCRYPT
-static bool checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr<DNSCryptQuery>& dnsCryptQuery, const ComboAddress& dest, const ComboAddress& remote, time_t now)
+boost::optional<std::vector<uint8_t>> checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr<DNSCryptQuery>& dnsCryptQuery, time_t now, bool tcp)
 {
   if (cs.dnscryptCtx) {
     vector<uint8_t> response;
@@ -1279,18 +1279,18 @@ static bool checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_
 
     dnsCryptQuery = std::make_shared<DNSCryptQuery>(cs.dnscryptCtx);
 
-    bool decrypted = handleDNSCryptQuery(const_cast<char*>(query), len, dnsCryptQuery, &decryptedQueryLen, false, now, response);
+    bool decrypted = handleDNSCryptQuery(const_cast<char*>(query), len, dnsCryptQuery, &decryptedQueryLen, tcp, now, response);
 
     if (!decrypted) {
       if (response.size() > 0) {
-        sendUDPResponse(cs.udpFD, reinterpret_cast<char*>(response.data()), static_cast<uint16_t>(response.size()), 0, dest, remote);
+        return response;
       }
-      return false;
+      throw std::runtime_error("Unable to decrypt DNSCrypt query, dropping.");
     }
 
     len = decryptedQueryLen;
   }
-  return true;
+  return boost::none;
 }
 #endif /* HAVE_DNSCRYPT */
 
@@ -1328,34 +1328,27 @@ static void queueResponse(const ClientState& cs, const char* response, uint16_t
 }
 #endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
 
-static int sendAndEncryptUDPResponse(LocalHolders& holders, ClientState& cs, const DNSQuestion& dq, char* response, uint16_t responseLen, std::shared_ptr<DNSCryptQuery>& dnsCryptQuery, int delayMsec, const ComboAddress& dest, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, char* respCBuf, bool cacheHit)
+static int sendResponse(LocalHolders& holders, ClientState& cs, DNSQuestion& dq, char* response, uint16_t responseLen, bool cacheHit, responseSender sender)
 {
-  DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(response), dq.size, responseLen, false, dq.queryTime);
+  DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(response), dq.size, responseLen, dq.tcp, dq.queryTime);
+
 #ifdef HAVE_PROTOBUF
   dr.uniqueId = dq.uniqueId;
 #endif
   dr.qTag = dq.qTag;
 
-  if (!processResponse(cacheHit ? holders.cacheHitRespRulactions : holders.selfAnsweredRespRulactions, dr, &delayMsec)) {
+  if (!processResponse(cacheHit ? holders.cacheHitRespRulactions : holders.selfAnsweredRespRulactions, dr, &dq.delayMsec)) {
     return -1;
   }
 
   if (!cs.muted) {
 #ifdef HAVE_DNSCRYPT
-    if (!encryptResponse(response, &responseLen, dq.size, false, dnsCryptQuery, nullptr, nullptr)) {
+    if (!encryptResponse(response, &responseLen, dq.size, dq.tcp, dq.dnsCryptQuery, nullptr, nullptr)) {
       return -1;
     }
 #endif
-#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
-    if (delayMsec == 0 && responsesVect != nullptr) {
-      queueResponse(cs, response, responseLen, dest, *dq.remote, responsesVect[*queuedResponses], respIOV, respCBuf);
-      (*queuedResponses)++;
-    }
-    else
-#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
-      {
-        sendUDPResponse(cs.udpFD, response, responseLen, delayMsec, dest, *dq.remote);
-      }
+
+    sender(cs, response, responseLen, dq.delayMsec, *dq.local, *dq.remote);
   }
 
   if (cacheHit) {
@@ -1376,72 +1369,43 @@ static int sendAndEncryptUDPResponse(LocalHolders& holders, ClientState& cs, con
   return 0;
 }
 
-static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest, char* query, uint16_t len, size_t queryBufferSize, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, char* respCBuf)
+/* returns nullptr if the query has been taken care of (cache-hit, self-answered or discarded) and a backend it should be sent to otherwise */
+std::shared_ptr<DownstreamState> processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, responseSender sender, bool& dropped)
 {
-  assert(responsesVect == nullptr || (queuedResponses != nullptr && respIOV != nullptr && respCBuf != nullptr));
-  uint16_t queryId = 0;
+  const uint16_t queryId = ntohs(dq.dh->id);
 
   try {
-    if (!isUDPQueryAcceptable(cs, holders, msgh, remote, dest)) {
-      return;
-    }
-
     /* we need an accurate ("real") value for the response and
        to store into the IDS, but not for insertion into the
        rings for example */
-    struct timespec queryRealTime;
     struct timespec now;
     gettime(&now);
-    gettime(&queryRealTime, true);
-
-    std::shared_ptr<DNSCryptQuery> dnsCryptQuery = nullptr;
-
-#ifdef HAVE_DNSCRYPT
-    if (!checkDNSCryptQuery(cs, query, len, dnsCryptQuery, dest, remote, queryRealTime.tv_sec)) {
-      return;
-    }
-#endif
-
-    struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query);
-    queryId = ntohs(dh->id);
-
-    if (!checkQueryHeaders(dh)) {
-      return;
-    }
 
     string poolname;
-    int delayMsec = 0;
-    const uint16_t * flags = getFlagsFromDNSHeader(dh);
-    const uint16_t origFlags = *flags;
-    uint16_t qtype, qclass;
-    unsigned int consumed = 0;
-    DNSName qname(query, len, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
-    DNSQuestion dq(&qname, qtype, qclass, consumed, dest.sin4.sin_family != 0 ? &dest : &cs.local, &remote, dh, queryBufferSize, len, false, &queryRealTime);
-    bool dnssecOK = false;
 
-    if (!processQuery(holders, dq, poolname, &delayMsec, now))
-    {
-      return;
+    if (!applyRulesToQuery(holders, dq, poolname, now)) {
+      dropped = true;
+      return nullptr;
     }
 
     if(dq.dh->qr) { // something turned it into a response
-      fixUpQueryTurnedResponse(dq, origFlags);
+      fixUpQueryTurnedResponse(dq, dq.origFlags);
 
       if (!cs.muted) {
-        char* response = query;
+        char* response = reinterpret_cast<char*>(dq.dh);
         uint16_t responseLen = dq.len;
 
-        sendAndEncryptUDPResponse(holders, cs, dq, response, responseLen, dnsCryptQuery, delayMsec, dest, responsesVect, queuedResponses, respIOV, respCBuf, false);
+        sendResponse(holders, cs, dq, response, responseLen, false, sender);
 
         ++g_stats.selfAnswered;
       }
 
-      return;
+      return nullptr;
     }
 
-    DownstreamState* ss = nullptr;
+    std::shared_ptr<DownstreamState> ss{nullptr};
     std::shared_ptr<ServerPool> serverPool = getPool(*holders.pools, poolname);
-    std::shared_ptr<DNSDistPacketCache> packetCache = serverPool->packetCache;
+    dq.packetCache = serverPool->packetCache;
     auto policy = *(holders.policy);
     if (serverPool->policy != nullptr) {
       policy = *(serverPool->policy);
@@ -1449,49 +1413,44 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     auto servers = serverPool->getServers();
     if (policy.isLua) {
       std::lock_guard<std::mutex> lock(g_luamutex);
-      ss = policy.policy(servers, &dq).get();
+      ss = policy.policy(servers, &dq);
     }
     else {
-      ss = policy.policy(servers, &dq).get();
+      ss = policy.policy(servers, &dq);
     }
 
-    bool ednsAdded = false;
-    bool ecsAdded = false;
-    uint32_t cacheKeyNoECS = 0;
-    uint32_t cacheKey = 0;
-    boost::optional<Netmask> subnet;
     uint16_t cachedResponseSize = dq.size;
     uint32_t allowExpired = ss ? 0 : g_staleCacheEntriesTTL;
-    bool useZeroScope = false;
 
-    if (packetCache && !dq.skipCache) {
-      dnssecOK = (getEDNSZ(dq) & EDNS_HEADER_FLAG_DO);
+    if (dq.packetCache && !dq.skipCache) {
+      dq.dnssecOK = (getEDNSZ(dq) & EDNS_HEADER_FLAG_DO);
     }
 
     if (dq.useECS && ((ss && ss->useECS) || (!ss && serverPool->getECS()))) {
       // we special case our cache in case a downstream explicitly gave us a universally valid response with a 0 scope
-      if (packetCache && !dq.skipCache && (!ss || !ss->disableZeroScope) && packetCache->isECSParsingEnabled()) {
-        if (packetCache->get(dq, consumed, dh->id, query, &cachedResponseSize, &cacheKeyNoECS, subnet, dnssecOK, allowExpired)) {
-          sendAndEncryptUDPResponse(holders, cs, dq, query, cachedResponseSize, dnsCryptQuery, delayMsec, dest, responsesVect, queuedResponses, respIOV, respCBuf, true);
-          return;
+      if (dq.packetCache && !dq.skipCache && (!ss || !ss->disableZeroScope) && dq.packetCache->isECSParsingEnabled()) {
+        if (dq.packetCache->get(dq, dq.consumed, dq.dh->id, reinterpret_cast<char*>(dq.dh), &cachedResponseSize, &dq.cacheKeyNoECS, dq.subnet, dq.dnssecOK, allowExpired)) {
+          sendResponse(holders, cs, dq, reinterpret_cast<char*>(dq.dh), cachedResponseSize, true, sender);
+          return nullptr;
         }
 
-        if (!subnet) {
+        if (!dq.subnet) {
           /* there was no existing ECS on the query, enable the zero-scope feature */
-          useZeroScope = true;
+          dq.useZeroScope = true;
         }
       }
 
-      if (!handleEDNSClientSubnet(dq, &(ednsAdded), &(ecsAdded), g_preserveTrailingData)) {
-        vinfolog("Dropping query from %s because we couldn't insert the ECS value", remote.toStringWithPort());
-        return;
+      if (!handleEDNSClientSubnet(dq, &(dq.ednsAdded), &(dq.ecsAdded), g_preserveTrailingData)) {
+        vinfolog("Dropping query from %s because we couldn't insert the ECS value", dq.remote->toStringWithPort());
+        dropped = true;
+        return nullptr;
       }
     }
 
-    if (packetCache && !dq.skipCache) {
-      if (packetCache->get(dq, consumed, dh->id, query, &cachedResponseSize, &cacheKey, subnet, dnssecOK, allowExpired)) {
-        sendAndEncryptUDPResponse(holders, cs, dq, query, cachedResponseSize, dnsCryptQuery, delayMsec, dest, responsesVect, queuedResponses, respIOV, respCBuf, true);
-        return;
+    if (dq.packetCache && !dq.skipCache) {
+      if (dq.packetCache->get(dq, dq.consumed, dq.dh->id, reinterpret_cast<char*>(dq.dh), &cachedResponseSize, &dq.cacheKey, dq.subnet, dq.dnssecOK, allowExpired)) {
+        sendResponse(holders, cs, dq, reinterpret_cast<char*>(dq.dh), cachedResponseSize, true, sender);
+        return nullptr;
       }
       ++g_stats.cacheMisses;
     }
@@ -1500,19 +1459,18 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
       ++g_stats.noPolicy;
 
       if (g_servFailOnNoPolicy && !cs.muted) {
-        char* response = query;
+        char* response = reinterpret_cast<char*>(dq.dh);
         uint16_t responseLen = dq.len;
-        restoreFlags(dh, origFlags);
+        restoreFlags(dq.dh, dq.origFlags);
 
         dq.dh->rcode = RCode::ServFail;
         dq.dh->qr = true;
 
-        sendAndEncryptUDPResponse(holders, cs, dq, response, responseLen, dnsCryptQuery, delayMsec, dest, responsesVect, queuedResponses, respIOV, respCBuf, false);
-
+        sendResponse(holders, cs, dq, response, responseLen, false, sender);
         // no response-only statistics counter to update.
       }
-      vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.qname->toString(), QType(dq.qtype).getName(), remote.toStringWithPort());
-      return;
+      vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.qname->toString(), QType(dq.qtype).getName(), dq.remote->toStringWithPort());
+      return nullptr;
     }
 
     if (dq.addXPF && ss->xpfRRCode != 0) {
@@ -1520,6 +1478,73 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     }
 
     ss->queries++;
+    return ss;
+  }
+  catch(const std::exception& e){
+    vinfolog("Got an error while parsing a %s query from %s, id %d: %s", (dq.tcp ? "TCP" : "UDP"), dq.remote->toStringWithPort(), queryId, e.what());
+    dropped = true;
+  }
+  return nullptr;
+}
+
+static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest, char* query, uint16_t len, size_t queryBufferSize, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, char* respCBuf)
+{
+  assert(responsesVect == nullptr || (queuedResponses != nullptr && respIOV != nullptr && respCBuf != nullptr));
+  uint16_t queryId = 0;
+
+  try {
+    if (!isUDPQueryAcceptable(cs, holders, msgh, remote, dest)) {
+      return;
+    }
+
+    /* we need an accurate ("real") value for the response and
+       to store into the IDS, but not for insertion into the
+       rings for example */
+    struct timespec queryRealTime;
+    struct timespec now;
+    gettime(&now);
+    gettime(&queryRealTime, true);
+
+    std::shared_ptr<DNSCryptQuery> dnsCryptQuery = nullptr;
+
+#ifdef HAVE_DNSCRYPT
+    auto dnsCryptResponse = checkDNSCryptQuery(cs, query, len, dnsCryptQuery, queryRealTime.tv_sec, false);
+    if (dnsCryptResponse) {
+      sendUDPResponse(cs.udpFD, reinterpret_cast<char*>(dnsCryptResponse->data()), static_cast<uint16_t>(dnsCryptResponse->size()), 0, dest, remote);
+      return;
+    }
+#endif
+
+    struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query);
+    queryId = ntohs(dh->id);
+
+    if (!checkQueryHeaders(dh)) {
+      return;
+    }
+
+    uint16_t qtype, qclass;
+    unsigned int consumed = 0;
+    DNSName qname(query, len, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
+    DNSQuestion dq(&qname, qtype, qclass, consumed, dest.sin4.sin_family != 0 ? &dest : &cs.local, &remote, dh, queryBufferSize, len, false, &queryRealTime);
+    dq.dnsCryptQuery = std::move(dnsCryptQuery);
+
+    responseSender sender = [&responsesVect, &queuedResponses, &respIOV, &respCBuf](const ClientState& cs, const char* data, uint16_t dataSize, int delayMsec, const ComboAddress& dest, const ComboAddress& remote) -> void {
+#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
+      if (delayMsec == 0 && responsesVect != nullptr) {
+        queueResponse(cs, data, dataSize, dest, remote, responsesVect[*queuedResponses], respIOV, respCBuf);
+        (*queuedResponses)++;
+        return;
+      }
+#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
+      sendUDPResponse(cs.udpFD, data, dataSize, delayMsec, dest, remote);
+    };
+
+    bool dropped = false;
+    auto ss = processQuery(dq, cs, holders, sender, dropped);
+
+    if (!ss) {
+      return;
+    }
 
     unsigned int idOffset = (ss->idOffset++) % ss->idStates.size();
     IDState* ids = &ss->idStates[idOffset];
@@ -1539,22 +1564,22 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     ids->origID = dh->id;
     ids->origRemote = remote;
     ids->sentTime.set(queryRealTime);
-    ids->qname = qname;
+    ids->qname = std::move(qname);
     ids->qtype = dq.qtype;
     ids->qclass = dq.qclass;
-    ids->delayMsec = delayMsec;
+    ids->delayMsec = dq.delayMsec;
     ids->tempFailureTTL = dq.tempFailureTTL;
-    ids->origFlags = origFlags;
-    ids->cacheKey = cacheKey;
-    ids->cacheKeyNoECS = cacheKeyNoECS;
-    ids->subnet = subnet;
+    ids->origFlags = dq.origFlags;
+    ids->cacheKey = dq.cacheKey;
+    ids->cacheKeyNoECS = dq.cacheKeyNoECS;
+    ids->subnet = dq.subnet;
     ids->skipCache = dq.skipCache;
-    ids->packetCache = packetCache;
-    ids->ednsAdded = ednsAdded;
-    ids->ecsAdded = ecsAdded;
-    ids->useZeroScope = useZeroScope;
+    ids->packetCache = dq.packetCache;
+    ids->ednsAdded = dq.ednsAdded;
+    ids->ecsAdded = dq.ecsAdded;
+    ids->useZeroScope = dq.useZeroScope;
     ids->qTag = dq.qTag;
-    ids->dnssecOK = dnssecOK;
+    ids->dnssecOK = dq.dnssecOK;
 
     /* If we couldn't harvest the real dest addr, still
        write down the listening addr since it will be useful
@@ -1571,10 +1596,10 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
       ids->destHarvested = false;
     }
 #ifdef HAVE_DNSCRYPT
-    ids->dnsCryptQuery = dnsCryptQuery;
+    ids->dnsCryptQuery = std::move(dq.dnsCryptQuery);
 #endif
 #ifdef HAVE_PROTOBUF
-    ids->uniqueId = dq.uniqueId;
+    ids->uniqueId = std::move(dq.uniqueId);
 #endif
 
     dh->id = idOffset;
@@ -1741,12 +1766,12 @@ uint16_t getRandomDNSID()
 #endif
 }
 
-static bool upCheck(DownstreamState& ds)
+static bool upCheck(const shared_ptr<DownstreamState>& ds)
 try
 {
-  DNSName checkName = ds.checkName;
-  uint16_t checkType = ds.checkType.getCode();
-  uint16_t checkClass = ds.checkClass;
+  DNSName checkName = ds->checkName;
+  uint16_t checkType = ds->checkType.getCode();
+  uint16_t checkClass = ds->checkClass;
   dnsheader checkHeader;
   memset(&checkHeader, 0, sizeof(checkHeader));
 
@@ -1754,13 +1779,13 @@ try
   checkHeader.id = getRandomDNSID();
 
   checkHeader.rd = true;
-  if (ds.setCD) {
+  if (ds->setCD) {
     checkHeader.cd = true;
   }
 
-  if (ds.checkFunction) {
+  if (ds->checkFunction) {
     std::lock_guard<std::mutex> lock(g_luamutex);
-    auto ret = ds.checkFunction(checkName, checkType, checkClass, &checkHeader);
+    auto ret = ds->checkFunction(checkName, checkType, checkClass, &checkHeader);
     checkName = std::get<0>(ret);
     checkType = std::get<1>(ret);
     checkClass = std::get<2>(ret);
@@ -1771,31 +1796,31 @@ try
   dnsheader * requestHeader = dpw.getHeader();
   *requestHeader = checkHeader;
 
-  Socket sock(ds.remote.sin4.sin_family, SOCK_DGRAM);
+  Socket sock(ds->remote.sin4.sin_family, SOCK_DGRAM);
   sock.setNonBlocking();
-  if (!IsAnyAddress(ds.sourceAddr)) {
+  if (!IsAnyAddress(ds->sourceAddr)) {
     sock.setReuseAddr();
-    sock.bind(ds.sourceAddr);
+    sock.bind(ds->sourceAddr);
   }
-  sock.connect(ds.remote);
-  ssize_t sent = udpClientSendRequestToBackend(&ds, sock.getHandle(), (char*)&packet[0], packet.size(), true);
+  sock.connect(ds->remote);
+  ssize_t sent = udpClientSendRequestToBackend(ds, sock.getHandle(), (char*)&packet[0], packet.size(), true);
   if (sent < 0) {
     int ret = errno;
     if (g_verboseHealthChecks)
-      infolog("Error while sending a health check query to backend %s: %d", ds.getNameWithAddr(), ret);
+      infolog("Error while sending a health check query to backend %s: %d", ds->getNameWithAddr(), ret);
     return false;
   }
 
-  int ret = waitForRWData(sock.getHandle(), true, /* ms to seconds */ ds.checkTimeout / 1000, /* remaining ms to us */ (ds.checkTimeout % 1000) * 1000);
+  int ret = waitForRWData(sock.getHandle(), true, /* ms to seconds */ ds->checkTimeout / 1000, /* remaining ms to us */ (ds->checkTimeout % 1000) * 1000);
   if(ret < 0 || !ret) { // error, timeout, both are down!
     if (ret < 0) {
       ret = errno;
       if (g_verboseHealthChecks)
-        infolog("Error while waiting for the health check response from backend %s: %d", ds.getNameWithAddr(), ret);
+        infolog("Error while waiting for the health check response from backend %s: %d", ds->getNameWithAddr(), ret);
     }
     else {
       if (g_verboseHealthChecks)
-        infolog("Timeout while waiting for the health check response from backend %s", ds.getNameWithAddr());
+        infolog("Timeout while waiting for the health check response from backend %s", ds->getNameWithAddr());
     }
     return false;
   }
@@ -1805,9 +1830,9 @@ try
   sock.recvFrom(reply, from);
 
   /* we are using a connected socket but hey.. */
-  if (from != ds.remote) {
+  if (from != ds->remote) {
     if (g_verboseHealthChecks)
-      infolog("Invalid health check response received from %s, expecting one from %s", from.toStringWithPort(), ds.remote.toStringWithPort());
+      infolog("Invalid health check response received from %s, expecting one from %s", from.toStringWithPort(), ds->remote.toStringWithPort());
     return false;
   }
 
@@ -1815,31 +1840,31 @@ try
 
   if (reply.size() < sizeof(*responseHeader)) {
     if (g_verboseHealthChecks)
-      infolog("Invalid health check response of size %d from backend %s, expecting at least %d", reply.size(), ds.getNameWithAddr(), sizeof(*responseHeader));
+      infolog("Invalid health check response of size %d from backend %s, expecting at least %d", reply.size(), ds->getNameWithAddr(), sizeof(*responseHeader));
     return false;
   }
 
   if (responseHeader->id != requestHeader->id) {
     if (g_verboseHealthChecks)
-      infolog("Invalid health check response id %d from backend %s, expecting %d", responseHeader->id, ds.getNameWithAddr(), requestHeader->id);
+      infolog("Invalid health check response id %d from backend %s, expecting %d", responseHeader->id, ds->getNameWithAddr(), requestHeader->id);
     return false;
   }
 
   if (!responseHeader->qr) {
     if (g_verboseHealthChecks)
-      infolog("Invalid health check response from backend %s, expecting QR to be set", ds.getNameWithAddr());
+      infolog("Invalid health check response from backend %s, expecting QR to be set", ds->getNameWithAddr());
     return false;
   }
 
   if (responseHeader->rcode == RCode::ServFail) {
     if (g_verboseHealthChecks)
-      infolog("Backend %s responded to health check with ServFail", ds.getNameWithAddr());
+      infolog("Backend %s responded to health check with ServFail", ds->getNameWithAddr());
     return false;
   }
 
-  if (ds.mustResolve && (responseHeader->rcode == RCode::NXDomain || responseHeader->rcode == RCode::Refused)) {
+  if (ds->mustResolve && (responseHeader->rcode == RCode::NXDomain || responseHeader->rcode == RCode::Refused)) {
     if (g_verboseHealthChecks)
-      infolog("Backend %s responded to health check with %s while mustResolve is set", ds.getNameWithAddr(), responseHeader->rcode == RCode::NXDomain ? "NXDomain" : "Refused");
+      infolog("Backend %s responded to health check with %s while mustResolve is set", ds->getNameWithAddr(), responseHeader->rcode == RCode::NXDomain ? "NXDomain" : "Refused");
     return false;
   }
 
@@ -1849,7 +1874,7 @@ try
 
   if (receivedName != checkName || receivedType != checkType || receivedClass != checkClass) {
     if (g_verboseHealthChecks)
-      infolog("Backend %s responded to health check with an invalid qname (%s vs %s), qtype (%s vs %s) or qclass (%d vs %d)", ds.getNameWithAddr(), receivedName.toLogString(), checkName.toLogString(), QType(receivedType).getName(), QType(checkType).getName(), receivedClass, checkClass);
+      infolog("Backend %s responded to health check with an invalid qname (%s vs %s), qtype (%s vs %s) or qclass (%d vs %d)", ds->getNameWithAddr(), receivedName.toLogString(), checkName.toLogString(), QType(receivedType).getName(), QType(checkType).getName(), receivedClass, checkClass);
     return false;
   }
 
@@ -1858,13 +1883,13 @@ try
 catch(const std::exception& e)
 {
   if (g_verboseHealthChecks)
-    infolog("Error checking the health of backend %s: %s", ds.getNameWithAddr(), e.what());
+    infolog("Error checking the health of backend %s: %s", ds->getNameWithAddr(), e.what());
   return false;
 }
 catch(...)
 {
   if (g_verboseHealthChecks)
-    infolog("Unknown exception while checking the health of backend %s", ds.getNameWithAddr());
+    infolog("Unknown exception while checking the health of backend %s", ds->getNameWithAddr());
   return false;
 }
 
@@ -1978,7 +2003,7 @@ static void healthChecksThread()
         continue;
       dss->lastCheck = 0;
       if(dss->availability==DownstreamState::Availability::Auto) {
-        bool newState=upCheck(*dss);
+        bool newState=upCheck(dss);
         if (newState) {
           /* check succeeded */
           dss->currentCheckFailures = 0;
@@ -2826,7 +2851,7 @@ try
 
   for(auto& dss : g_dstates.getCopy()) { // it is a copy, but the internal shared_ptrs are the real deal
     if(dss->availability==DownstreamState::Availability::Auto) {
-      bool newState=upCheck(*dss);
+      bool newState=upCheck(dss);
       warnlog("Marking downstream %s as '%s'", dss->getNameWithAddr(), newState ? "up" : "down");
       dss->upStatus = newState;
     }
index 8f33854d88bc05e4b701d0f9703ec3924a710c42..137f250bc69b0de769e1834424d283cdd6aabfd1 100644 (file)
@@ -61,33 +61,47 @@ typedef std::unordered_map<string, string> QTag;
 struct DNSQuestion
 {
   DNSQuestion(const DNSName* name, uint16_t type, uint16_t class_, unsigned int consumed_, const ComboAddress* lc, const ComboAddress* rem, struct dnsheader* header, size_t bufferSize, uint16_t queryLen, bool isTcp, const struct timespec* queryTime_):
-    qname(name), qtype(type), qclass(class_), local(lc), remote(rem), dh(header), size(bufferSize), consumed(consumed_), len(queryLen), ecsPrefixLength(rem->sin4.sin_family == AF_INET ? g_ECSSourcePrefixV4 : g_ECSSourcePrefixV6), tempFailureTTL(boost::none), tcp(isTcp), queryTime(queryTime_), ecsOverride(g_ECSOverride) { }
+    qname(name), local(lc), remote(rem), dh(header), queryTime(queryTime_), size(bufferSize), consumed(consumed_), tempFailureTTL(boost::none), qtype(type), qclass(class_), len(queryLen), ecsPrefixLength(rem->sin4.sin_family == AF_INET ? g_ECSSourcePrefixV4 : g_ECSSourcePrefixV6), tcp(isTcp), ecsOverride(g_ECSOverride) {
+    const uint16_t* flags = getFlagsFromDNSHeader(dh);
+    origFlags = *flags;
+  }
 
 #ifdef HAVE_PROTOBUF
   boost::optional<boost::uuids::uuid> uniqueId;
 #endif
   Netmask ecs;
-  const DNSName* qname;
-  const uint16_t qtype;
-  const uint16_t qclass;
-  const ComboAddress* local;
-  const ComboAddress* remote;
+  boost::optional<Netmask> subnet;
+  const DNSName* qname{nullptr};
+  const ComboAddress* local{nullptr};
+  const ComboAddress* remote{nullptr};
   std::shared_ptr<QTag> qTag{nullptr};
   std::shared_ptr<std::map<uint16_t, EDNSOptionView> > ednsOptions;
-  struct dnsheader* dh;
+  std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};
+  std::shared_ptr<DNSDistPacketCache> packetCache{nullptr};
+  struct dnsheader* dh{nullptr};
+  const struct timespec* queryTime{nullptr};
   size_t size;
   unsigned int consumed{0};
+  int delayMsec{0};
+  boost::optional<uint32_t> tempFailureTTL;
+  uint32_t cacheKeyNoECS;
+  uint32_t cacheKey;
+  const uint16_t qtype;
+  const uint16_t qclass;
   uint16_t len;
   uint16_t ecsPrefixLength;
+  uint16_t origFlags;
   uint8_t ednsRCode{0};
-  boost::optional<uint32_t> tempFailureTTL;
   const bool tcp;
-  const struct timespec* queryTime;
   bool skipCache{false};
   bool ecsOverride;
   bool useECS{true};
   bool addXPF{true};
   bool ecsSet{false};
+  bool ecsAdded{false};
+  bool ednsAdded{false};
+  bool useZeroScope{false};
+  bool dnssecOK{false};
 };
 
 struct DNSResponse : DNSQuestion
@@ -1032,11 +1046,9 @@ bool getLuaNoSideEffect(); // set if there were only explicit declarations of _n
 void resetLuaSideEffect(); // reset to indeterminate state
 
 bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote, unsigned int& consumed);
-bool processQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, int* delayMsec, const struct timespec& now);
 bool processResponse(LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRulactions, DNSResponse& dr, int* delayMsec);
-bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags);
 bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, std::vector<uint8_t>& rewrittenResponse, uint16_t addRoom, bool* zeroScope);
-void restoreFlags(struct dnsheader* dh, uint16_t origFlags);
+
 bool checkQueryHeaders(const struct dnsheader* dh);
 
 #ifdef HAVE_DNSCRYPT
@@ -1044,6 +1056,8 @@ extern std::vector<std::tuple<ComboAddress, std::shared_ptr<DNSCryptContext>, bo
 
 bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr<DNSCryptQuery> dnsCryptQuery, dnsheader** dh, dnsheader* dhCopy);
 int handleDNSCryptQuery(char* packet, uint16_t len, std::shared_ptr<DNSCryptQuery> query, uint16_t* decryptedQueryLen, bool tcp, time_t now, std::vector<uint8_t>& response);
+
+boost::optional<std::vector<uint8_t>> checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr<DNSCryptQuery>& dnsCryptQuery, time_t now, bool tcp);
 #endif
 
 bool addXPF(DNSQuestion& dq, uint16_t optionCode);
@@ -1058,3 +1072,7 @@ extern DNSDistSNMPAgent* g_snmpAgent;
 extern bool g_addEDNSToSelfGeneratedResponses;
 
 static const size_t s_udpIncomingBufferSize{1500};
+
+typedef std::function<void(const ClientState& cs, const char* data, uint16_t dataSize, int delayMsec, const ComboAddress& dest, const ComboAddress& remote)> responseSender;
+std::shared_ptr<DownstreamState> processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, responseSender sender, bool& dropped);
+
index cd2a22158d5044d7e706d1240bd62fc3717d8231..357c902db74837e5f9313f997f965fd2ed6ab425 100644 (file)
@@ -22,7 +22,8 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) {
   uint16_t qclass = QClass::IN;
   ComboAddress lc("127.0.0.1:53");
   ComboAddress rem("192.0.2.1:42");
-  struct dnsheader* dh = nullptr;
+  struct dnsheader dh;
+  memset(&dh, 0, sizeof(dh));
   size_t bufferSize = 0;
   size_t queryLen = 0;
   bool isTcp = false;
@@ -32,7 +33,7 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) {
   /* the internal QPS limiter does not use the real time */
   gettime(&expiredTime);
 
-  DNSQuestion dq(&qname, qtype, qclass, qname.wirelength(), &lc, &rem, dh, bufferSize, queryLen, isTcp, &queryRealTime);
+  DNSQuestion dq(&qname, qtype, qclass, qname.wirelength(), &lc, &rem, &dh, bufferSize, queryLen, isTcp, &queryRealTime);
 
   for (size_t idx = 0; idx < maxQPS; idx++) {
     /* let's use different source ports, it shouldn't matter */