]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: First working version of cross-protocol DoH -> TCP
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 31 Mar 2021 15:22:21 +0000 (17:22 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 26 Aug 2021 14:30:25 +0000 (16:30 +0200)
29 files changed:
pdns/Makefile.am
pdns/dnsdist-console.cc
pdns/dnsdist-idstate.hh [new file with mode: 0644]
pdns/dnsdist-lua-actions.cc
pdns/dnsdist-lua-bindings-dnsquestion.cc
pdns/dnsdist-lua-inspection.cc
pdns/dnsdist-lua-rules.cc
pdns/dnsdist-lua.cc
pdns/dnsdist-protobuf.cc
pdns/dnsdist-protocols.hh [new file with mode: 0644]
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/dnsdist-idstate.hh [new symlink]
pdns/dnsdistdist/dnsdist-protocols.cc [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-protocols.hh [new symlink]
pdns/dnsdistdist/dnsdist-tcp-downstream.cc
pdns/dnsdistdist/dnsdist-tcp-downstream.hh
pdns/dnsdistdist/dnsdist-tcp-upstream.hh
pdns/dnsdistdist/dnsdist-tcp.hh [new file with mode: 0644]
pdns/dnsdistdist/doh.cc
pdns/dnsdistdist/test-dnsdistkvs_cc.cc
pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc
pdns/dnsdistdist/test-dnsdistrules_cc.cc
pdns/dnsdistdist/test-dnsdisttcp_cc.cc
pdns/doh.hh
pdns/test-dnsdist_cc.cc
pdns/test-dnsdistpacketcache_cc.cc

index 137de97c83cf426f04727b355e3fa5a05f7a89d4..a17aa95e0aeae8a5c3cdfce2fdeb33ebb0063e9a 100644 (file)
@@ -1562,6 +1562,8 @@ fuzz_target_dnsdistcache_SOURCES = \
        dns.cc dns.hh \
        dnsdist-cache.cc dnsdist-cache.hh \
        dnsdist-ecs.cc dnsdist-ecs.hh \
+       dnsdist-idstate.hh \
+       dnsdist-protocols.hh \
        dnslabeltext.cc \
        dnsname.cc dnsname.hh \
        dnsparser.cc dnsparser.hh \
index 2ab26821c5194df4ffcb18517d7116a6ab881874..4a5599519de455287531b48e525c7ca6920250f0 100644 (file)
@@ -622,7 +622,6 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "setSyslogFacility", true, "facility", "set the syslog logging facility to 'facility'. Defaults to LOG_DAEMON" },
   { "setTCPDownstreamCleanupInterval", true, "interval", "minimum interval in seconds between two cleanups of the idle TCP downstream connections" },
   { "setTCPInternalPipeBufferSize", true, "size", "Set the size in bytes of the internal buffer of the pipes used internally to distribute connections to TCP (and DoT) workers threads" },
-  { "setTCPUseSinglePipe", true, "bool", "whether the incoming TCP connections should be put into a single queue instead of using per-thread queues. Defaults to false" },
   { "setTCPRecvTimeout", true, "n", "set the read timeout on TCP connections from the client, in seconds" },
   { "setTCPSendTimeout", true, "n", "set the write timeout on TCP connections from the client, in seconds" },
   { "setUDPMultipleMessagesVectorSize", true, "n", "set the size of the vector passed to recvmmsg() to receive UDP messages. Default to 1 which means that the feature is disabled and recvmsg() is used instead" },
diff --git a/pdns/dnsdist-idstate.hh b/pdns/dnsdist-idstate.hh
new file mode 100644 (file)
index 0000000..69d882b
--- /dev/null
@@ -0,0 +1,258 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+#include "config.h"
+#include "dnsname.hh"
+#include "dnsdist-protocols.hh"
+#include "gettime.hh"
+#include "iputils.hh"
+#include "uuid-utils.hh"
+
+struct ClientState;
+struct DOHUnit;
+class DNSCryptQuery;
+class DNSDistPacketCache;
+
+using QTag = std::unordered_map<string, string>;
+
+struct StopWatch
+{
+  StopWatch(bool realTime=false): d_needRealTime(realTime)
+  {
+  }
+
+  void start() {
+    if (gettime(&d_start, d_needRealTime) < 0) {
+      unixDie("Getting timestamp");
+    }
+  }
+
+  void set(const struct timespec& from) {
+    d_start = from;
+  }
+
+  double udiff() const {
+    struct timespec now;
+    if (gettime(&now, d_needRealTime) < 0) {
+      unixDie("Getting timestamp");
+    }
+
+    return 1000000.0*(now.tv_sec - d_start.tv_sec) + (now.tv_nsec - d_start.tv_nsec)/1000.0;
+  }
+
+  double udiffAndSet() {
+    struct timespec now;
+    if (gettime(&now, d_needRealTime) < 0) {
+      unixDie("Getting timestamp");
+    }
+
+    auto ret= 1000000.0*(now.tv_sec - d_start.tv_sec) + (now.tv_nsec - d_start.tv_nsec)/1000.0;
+    d_start = now;
+    return ret;
+  }
+
+  struct timespec d_start{0,0};
+private:
+  bool d_needRealTime{false};
+};
+
+/* g++ defines __SANITIZE_THREAD__
+   clang++ supports the nice __has_feature(thread_sanitizer),
+   let's merge them */
+#if defined(__has_feature)
+#if __has_feature(thread_sanitizer)
+#define __SANITIZE_THREAD__ 1
+#endif
+#endif
+
+struct IDState
+{
+  IDState(): sentTime(true), tempFailureTTL(boost::none) { origDest.sin4.sin_family = 0;}
+  IDState(const IDState& orig) = delete;
+  IDState(IDState&& rhs): subnet(rhs.subnet), origRemote(rhs.origRemote), origDest(rhs.origDest), hopRemote(rhs.hopRemote), hopLocal(rhs.hopLocal), qname(std::move(rhs.qname)), sentTime(rhs.sentTime), dnsCryptQuery(std::move(rhs.dnsCryptQuery)), packetCache(std::move(rhs.packetCache)), qTag(std::move(rhs.qTag)), tempFailureTTL(rhs.tempFailureTTL), cs(rhs.cs), du(std::move(rhs.du)), cacheKey(rhs.cacheKey), cacheKeyNoECS(rhs.cacheKeyNoECS), origFD(rhs.origFD), delayMsec(rhs.delayMsec), qtype(rhs.qtype), qclass(rhs.qclass), origID(rhs.origID), origFlags(rhs.origFlags), cacheFlags(rhs.cacheFlags), protocol(rhs.protocol), ednsAdded(rhs.ednsAdded), ecsAdded(rhs.ecsAdded), skipCache(rhs.skipCache), destHarvested(rhs.destHarvested), dnssecOK(rhs.dnssecOK), useZeroScope(rhs.useZeroScope)
+  {
+    if (rhs.isInUse()) {
+      throw std::runtime_error("Trying to move an in-use IDState");
+    }
+
+    uniqueId = std::move(rhs.uniqueId);
+#ifdef __SANITIZE_THREAD__
+    age.store(rhs.age.load());
+#else
+    age = rhs.age;
+#endif
+  }
+
+  IDState& operator=(IDState&& rhs)
+  {
+    if (isInUse()) {
+      throw std::runtime_error("Trying to overwrite an in-use IDState");
+    }
+
+    if (rhs.isInUse()) {
+      throw std::runtime_error("Trying to move an in-use IDState");
+    }
+
+    subnet = std::move(rhs.subnet);
+    origRemote = rhs.origRemote;
+    origDest = rhs.origDest;
+    hopRemote = rhs.hopRemote;
+    hopLocal = rhs.hopLocal;
+    qname = std::move(rhs.qname);
+    sentTime = rhs.sentTime;
+    dnsCryptQuery = std::move(rhs.dnsCryptQuery);
+    packetCache = std::move(rhs.packetCache);
+    qTag = std::move(rhs.qTag);
+    tempFailureTTL = std::move(rhs.tempFailureTTL);
+    cs = rhs.cs;
+    du = std::move(rhs.du);
+    cacheKey = rhs.cacheKey;
+    cacheKeyNoECS = rhs.cacheKeyNoECS;
+    origFD = rhs.origFD;
+    delayMsec = rhs.delayMsec;
+#ifdef __SANITIZE_THREAD__
+    age.store(rhs.age.load());
+#else
+    age = rhs.age;
+#endif
+    qtype = rhs.qtype;
+    qclass = rhs.qclass;
+    origID = rhs.origID;
+    origFlags = rhs.origFlags;
+    cacheFlags = rhs.cacheFlags;
+    protocol = rhs.protocol;
+    uniqueId = std::move(rhs.uniqueId);
+    ednsAdded = rhs.ednsAdded;
+    ecsAdded = rhs.ecsAdded;
+    skipCache = rhs.skipCache;
+    destHarvested = rhs.destHarvested;
+    dnssecOK = rhs.dnssecOK;
+    useZeroScope = rhs.useZeroScope;
+
+    return *this;
+  }
+
+  static const int64_t unusedIndicator = -1;
+
+  static bool isInUse(int64_t usageIndicator)
+  {
+    return usageIndicator != unusedIndicator;
+  }
+
+  bool isInUse() const
+  {
+    return usageIndicator != unusedIndicator;
+  }
+
+  /* return true if the value has been successfully replaced meaning that
+     no-one updated the usage indicator in the meantime */
+  bool tryMarkUnused(int64_t expectedUsageIndicator)
+  {
+    return usageIndicator.compare_exchange_strong(expectedUsageIndicator, unusedIndicator);
+  }
+
+  /* mark as used no matter what, return true if the state was in use before */
+  bool markAsUsed()
+  {
+    auto currentGeneration = generation++;
+    return markAsUsed(currentGeneration);
+  }
+
+  /* mark as used no matter what, return true if the state was in use before */
+  bool markAsUsed(int64_t currentGeneration)
+  {
+    int64_t oldUsage = usageIndicator.exchange(currentGeneration);
+    return oldUsage != unusedIndicator;
+  }
+
+  /* We use this value to detect whether this state is in use.
+     For performance reasons we don't want to use a lock here, but that means
+     we need to be very careful when modifying this value. Modifications happen
+     from:
+     - one of the UDP or DoH 'client' threads receiving a query, selecting a backend
+       then picking one of the states associated to this backend (via the idOffset).
+       Most of the time this state should not be in use and usageIndicator is -1, but we
+       might not yet have received a response for the query previously associated to this
+       state, meaning that we will 'reuse' this state and erase the existing state.
+       If we ever receive a response for this state, it will be discarded. This is
+       mostly fine for UDP except that we still need to be careful in order to miss
+       the 'outstanding' counters, which should only be increased when we are picking
+       an empty state, and not when reusing ;
+       For DoH, though, we have dynamically allocated a DOHUnit object that needs to
+       be freed, as well as internal objects internals to libh2o.
+     - one of the UDP receiver threads receiving a response from a backend, picking
+       the corresponding state and sending the response to the client ;
+     - the 'healthcheck' thread scanning the states to actively discover timeouts,
+       mostly to keep some counters like the 'outstanding' one sane.
+     We previously based that logic on the origFD (FD on which the query was received,
+     and therefore from where the response should be sent) but this suffered from an
+     ABA problem since it was quite likely that a UDP 'client thread' would reset it to the
+     same value since we only have so much incoming sockets:
+     - 1/ 'client' thread gets a query and set origFD to its FD, say 5 ;
+     - 2/ 'receiver' thread gets a response, read the value of origFD to 5, check that the qname,
+       qtype and qclass match
+     - 3/ during that time the 'client' thread reuses the state, setting again origFD to 5 ;
+     - 4/ the 'receiver' thread uses compare_exchange_strong() to only replace the value if it's still
+       5, except it's not the same 5 anymore and it overrides a fresh state.
+     We now use a 32-bit unsigned counter instead, which is incremented every time the state is set,
+     wrapping around if necessary, and we set an atomic signed 64-bit value, so that we still have -1
+     when the state is unused and the value of our counter otherwise.
+  */
+  boost::optional<Netmask> subnet{boost::none};               // 40
+  ComboAddress origRemote;                                    // 28
+  ComboAddress origDest;                                      // 28
+  ComboAddress hopRemote;
+  ComboAddress hopLocal;
+  DNSName qname;                                              // 24
+  StopWatch sentTime;                                         // 16
+  std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};      // 16
+  std::shared_ptr<DNSDistPacketCache> packetCache{nullptr};   // 16
+  std::shared_ptr<QTag> qTag{nullptr};                        // 16
+  boost::optional<uint32_t> tempFailureTTL;                   // 8
+  const ClientState* cs{nullptr};                             // 8
+  DOHUnit* du{nullptr};                                       // 8
+  std::atomic<int64_t> usageIndicator{unusedIndicator};  // set to unusedIndicator to indicate this state is empty   // 8
+  std::atomic<uint32_t> generation{0}; // increased every time a state is used, to be able to detect an ABA issue    // 4
+  uint32_t cacheKey{0};                                       // 4
+  uint32_t cacheKeyNoECS{0};                                  // 4
+  int origFD{-1};                                             // 4
+  int delayMsec{0};
+#ifdef __SANITIZE_THREAD__
+  std::atomic<uint16_t> age{0};
+#else
+  uint16_t age{0};                                            // 2
+#endif
+  uint16_t qtype{0};                                          // 2
+  uint16_t qclass{0};                                         // 2
+  uint16_t origID{0};                                         // 2
+  uint16_t origFlags{0};                                      // 2
+  uint16_t cacheFlags{0}; // DNS flags as sent to the backend // 2
+  dnsdist::Protocol protocol;                                 // 1
+  boost::optional<boost::uuids::uuid> uniqueId{boost::none};  // 17 (placed here to reduce the space lost to padding)
+  bool ednsAdded{false};
+  bool ecsAdded{false};
+  bool skipCache{false};
+  bool destHarvested{false}; // if true, origDest holds the original dest addr, otherwise the listening addr
+  bool dnssecOK{false};
+  bool useZeroScope{false};
+};
index b5c070396729512173cc302a5de5460c4e4b649d..12df54ae93c1fe33b0f38a75cef9eee583d505a4 100644 (file)
@@ -1224,23 +1224,23 @@ private:
   bool d_hasV6;
 };
 
-static DnstapMessage::ProtocolType ProtocolToDNSTap(DNSQuestion::Protocol protocol)
+static DnstapMessage::ProtocolType ProtocolToDNSTap(dnsdist::Protocol protocol)
 {
   DnstapMessage::ProtocolType result;
   switch (protocol) {
   default:
-  case DNSQuestion::Protocol::DoUDP:
-  case DNSQuestion::Protocol::DNSCryptUDP:
+  case dnsdist::Protocol::DoUDP:
+  case dnsdist::Protocol::DNSCryptUDP:
     result = DnstapMessage::ProtocolType::DoUDP;
     break;
-  case DNSQuestion::Protocol::DoTCP:
-  case DNSQuestion::Protocol::DNSCryptTCP:
+  case dnsdist::Protocol::DoTCP:
+  case dnsdist::Protocol::DNSCryptTCP:
     result = DnstapMessage::ProtocolType::DoTCP;
     break;
-  case DNSQuestion::Protocol::DoT:
+  case dnsdist::Protocol::DoT:
     result = DnstapMessage::ProtocolType::DoT;
     break;
-  case DNSQuestion::Protocol::DoH:
+  case dnsdist::Protocol::DoH:
     result = DnstapMessage::ProtocolType::DoH;
     break;
   }
index f213335e268f03f1cf313d4bb7192e2d85d63ce8..3ddfa4415728828912b052af9c4d1f4e467d344b 100644 (file)
@@ -74,7 +74,7 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
     });
 
   luaCtx.registerFunction<std::string (DNSQuestion::*)()const>("getProtocol", [](const DNSQuestion& dq) {
-    return DNSQuestion::ProtocolToString(dq.getProtocol());
+    return dnsdist::ProtocolToString(dq.getProtocol());
   });
 
   luaCtx.registerFunction<void(DNSQuestion::*)(std::string)>("sendTrap", [](const DNSQuestion& dq, boost::optional<std::string> reason) {
@@ -252,7 +252,7 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
     });
 
   luaCtx.registerFunction<std::string (DNSResponse::*)()const>("getProtocol", [](const DNSResponse& dr) {
-    return DNSQuestion::ProtocolToString(dr.getProtocol());
+    return dnsdist::ProtocolToString(dr.getProtocol());
   });
 
   luaCtx.registerFunction<void(DNSResponse::*)(std::string)>("sendTrap", [](const DNSResponse& dr, boost::optional<std::string> reason) {
index b1d6a38b33f59ab4ead1a0027e20575931085d21..4c9024c8bbe690674f080bbf8db95ee48bbb0b39 100644 (file)
@@ -23,6 +23,7 @@
 #include "dnsdist-lua.hh"
 #include "dnsdist-dynblocks.hh"
 #include "dnsdist-rings.hh"
+#include "dnsdist-tcp.hh"
 
 #include "statnode.hh"
 
@@ -597,9 +598,6 @@ void setupLuaInspection(LuaContext& luaCtx)
       ret << (fmt % g_tcpclientthreads->getThreadsCount() % (g_maxTCPClientThreads ? *g_maxTCPClientThreads : 0) % g_tcpclientthreads->getQueuedCount() % g_maxTCPQueuedConnections) << endl;
       ret << endl;
 
-      ret << "Query distribution mode is: " << std::string(g_useTCPSinglePipe ? "single queue" : "per-thread queues") << endl;
-      ret << endl;
-
       ret << "Frontends:" << endl;
       fmt = boost::format("%-3d %-20.20s %-20d %-20d %-20d %-25d %-20d %-20d %-20d %-20f %-20f %-20d %-20d %-25d %-25d %-15d %-15d %-15d %-15d %-15d");
       ret << (fmt % "#" % "Address" % "Connections" % "Max concurrent conn" % "Died reading query" % "Died sending response" % "Gave up" % "Client timeouts" % "Downstream timeouts" % "Avg queries/conn" % "Avg duration" % "TLS new sessions" % "TLS Resumptions" % "TLS unknown ticket keys" % "TLS inactive ticket keys" % "TLS 1.0" % "TLS 1.1" % "TLS 1.2" % "TLS 1.3" % "TLS other") << endl;
index 9c9bec0afae774277de1bb1f86a19ca6084d881f..12e44a042c88e1344340e649ce959370d0e59c21 100644 (file)
@@ -444,7 +444,7 @@ void setupLuaRules(LuaContext& luaCtx)
       sw.start();
       for(int n=0; n < times; ++n) {
         item& i = items[n % items.size()];
-        DNSQuestion dq(&i.qname, i.qtype, i.qclass, &i.rem, &i.rem, i.packet, DNSQuestion::Protocol::DoUDP, &sw.d_start);
+        DNSQuestion dq(&i.qname, i.qtype, i.qclass, &i.rem, &i.rem, i.packet, dnsdist::Protocol::DoUDP, &sw.d_start);
         if (rule->matches(&dq)) {
           matches++;
         }
index 879b785e5a210480e6ae7f5b96ea63db85d3cd37..b1f0925fc78a6afec41f19d43b9d9cccc8fe4dbb 100644 (file)
@@ -1843,15 +1843,6 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       g_hashperturb = pertub;
     });
 
-  luaCtx.writeFunction("setTCPUseSinglePipe", [](bool flag) {
-      if (g_configurationDone) {
-        g_outputBuffer="setTCPUseSinglePipe() cannot be used at runtime!\n";
-        return;
-      }
-      setLuaSideEffect();
-      g_useTCPSinglePipe = flag;
-    });
-
   luaCtx.writeFunction("setTCPInternalPipeBufferSize", [](size_t size) { g_tcpInternalPipeBufferSize = size; });
 
   luaCtx.writeFunction("snmpAgent", [client,configCheck](bool enableTraps, boost::optional<std::string> daemonSocket) {
index 8e91640096fe03b740a97ab26d0c757ff6e9137b..e2e6c04106d5b408ed12560a266adb071c75d86b 100644 (file)
@@ -124,7 +124,7 @@ void DNSDistProtoBufMessage::serialize(std::string& data) const
     m.setTime(ts.tv_sec, ts.tv_nsec / 1000);
   }
 
-  m.setRequest(d_dq.uniqueId ? *d_dq.uniqueId : getUniqueID(), d_requestor ? *d_requestor : *d_dq.remote, d_responder ? *d_responder : *d_dq.local, d_question ? d_question->d_name : *d_dq.qname, d_question ? d_question->d_type : d_dq.qtype, d_question ? d_question->d_class : d_dq.qclass, d_dq.getHeader()->id, (d_dq.getProtocol() == DNSQuestion::Protocol::DoH) ? true : d_dq.overTCP(), d_bytes ? *d_bytes : d_dq.getData().size());
+  m.setRequest(d_dq.uniqueId ? *d_dq.uniqueId : getUniqueID(), d_requestor ? *d_requestor : *d_dq.remote, d_responder ? *d_responder : *d_dq.local, d_question ? d_question->d_name : *d_dq.qname, d_question ? d_question->d_type : d_dq.qtype, d_question ? d_question->d_class : d_dq.qclass, d_dq.getHeader()->id, (d_dq.getProtocol() == dnsdist::Protocol::DoH) ? true : d_dq.overTCP(), d_bytes ? *d_bytes : d_dq.getData().size());
 
   if (d_serverIdentity) {
     m.setServerIdentity(*d_serverIdentity);
diff --git a/pdns/dnsdist-protocols.hh b/pdns/dnsdist-protocols.hh
new file mode 100644 (file)
index 0000000..271c241
--- /dev/null
@@ -0,0 +1,31 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+#include <vector>
+#include <string>
+
+namespace dnsdist {
+  enum class Protocol : uint8_t { DoUDP, DoTCP, DNSCryptUDP, DNSCryptTCP, DoT, DoH };
+
+  const std::string& ProtocolToString(Protocol proto);
+}
index 0a7da953723aa92a36268dc2d28f0e8ea22fcc78..cea13fc631cdac35049d730e20bbd710d846e5bd 100644 (file)
@@ -28,6 +28,7 @@
 #include "dnsdist-ecs.hh"
 #include "dnsdist-proxy-protocol.hh"
 #include "dnsdist-rings.hh"
+#include "dnsdist-tcp.hh"
 #include "dnsdist-tcp-downstream.hh"
 #include "dnsdist-tcp-upstream.hh"
 #include "dnsdist-xpf.hh"
@@ -72,7 +73,6 @@ uint64_t g_maxTCPQueuedConnections{1000};
 uint16_t g_downstreamTCPCleanupInterval{60};
 int g_tcpRecvTimeout{2};
 int g_tcpSendTimeout{2};
-bool g_useTCPSinglePipe{false};
 std::atomic<uint64_t> g_tcpStatesDumpRequested{0};
 
 class DownstreamConnectionsManager
@@ -110,7 +110,7 @@ public:
       }
     }
 
-    return std::make_shared<TCPConnectionToBackend>(ds, now);
+    return std::make_shared<TCPConnectionToBackend>(ds, mplexer, now);
   }
 
   static void releaseDownstreamConnection(std::shared_ptr<TCPConnectionToBackend>&& conn)
@@ -248,100 +248,80 @@ std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getDownstrea
   return downstream;
 }
 
-static void tcpClientThread(int pipefd);
+static void tcpClientThread(int pipefd, int crossProtocolPipeFD);
 
-TCPClientCollection::TCPClientCollection(size_t maxThreads, bool useSinglePipe): d_tcpclientthreads(maxThreads), d_maxthreads(maxThreads), d_singlePipe{-1,-1}, d_useSinglePipe(useSinglePipe)
+TCPClientCollection::TCPClientCollection(size_t maxThreads): d_tcpclientthreads(maxThreads), d_maxthreads(maxThreads)
 {
-  if (d_useSinglePipe) {
-    if (pipe(d_singlePipe) < 0) {
-      int err = errno;
-      throw std::runtime_error("Error creating the TCP single communication pipe: " + stringerror(err));
-    }
-
-    if (!setNonBlocking(d_singlePipe[0])) {
-      int err = errno;
-      close(d_singlePipe[0]);
-      close(d_singlePipe[1]);
-      throw std::runtime_error("Error setting the TCP single communication pipe non-blocking: " + stringerror(err));
-    }
-
-    if (!setNonBlocking(d_singlePipe[1])) {
-      int err = errno;
-      close(d_singlePipe[0]);
-      close(d_singlePipe[1]);
-      throw std::runtime_error("Error setting the TCP single communication pipe non-blocking: " + stringerror(err));
-    }
-
-    if (g_tcpInternalPipeBufferSize > 0 && getPipeBufferSize(d_singlePipe[0]) < g_tcpInternalPipeBufferSize) {
-      setPipeBufferSize(d_singlePipe[0], g_tcpInternalPipeBufferSize);
-    }
-  }
 }
 
 void TCPClientCollection::addTCPClientThread()
 {
-  int pipefds[2] = { -1, -1};
-
-  vinfolog("Adding TCP Client thread");
-
-  if (d_useSinglePipe) {
-    pipefds[0] = d_singlePipe[0];
-    pipefds[1] = d_singlePipe[1];
-  }
-  else {
-    if (pipe(pipefds) < 0) {
-      errlog("Error creating the TCP thread communication pipe: %s", stringerror());
-      return;
+  auto preparePipe = [](int fds[2], const std::string& type) -> bool {
+    if (pipe(fds) < 0) {
+      errlog("Error creating the TCP thread %s pipe: %s", type, stringerror());
+      return false;
     }
 
-    if (!setNonBlocking(pipefds[0])) {
+    if (!setNonBlocking(fds[0])) {
       int err = errno;
-      close(pipefds[0]);
-      close(pipefds[1]);
-      errlog("Error setting the TCP thread communication pipe non-blocking: %s", stringerror(err));
-      return;
+      close(fds[0]);
+      close(fds[1]);
+      errlog("Error setting the TCP thread %s pipe non-blocking: %s", type, stringerror(err));
+      return false;
     }
 
-    if (!setNonBlocking(pipefds[1])) {
+    if (!setNonBlocking(fds[1])) {
       int err = errno;
-      close(pipefds[0]);
-      close(pipefds[1]);
-      errlog("Error setting the TCP thread communication pipe non-blocking: %s", stringerror(err));
-      return;
+      close(fds[0]);
+      close(fds[1]);
+      errlog("Error setting the TCP thread %s pipe non-blocking: %s", type, stringerror(err));
+      return false;
     }
 
-    if (g_tcpInternalPipeBufferSize > 0 && getPipeBufferSize(pipefds[0]) < g_tcpInternalPipeBufferSize) {
-      setPipeBufferSize(pipefds[0], g_tcpInternalPipeBufferSize);
+    if (g_tcpInternalPipeBufferSize > 0 && getPipeBufferSize(fds[0]) < g_tcpInternalPipeBufferSize) {
+      setPipeBufferSize(fds[0], g_tcpInternalPipeBufferSize);
     }
+
+    return true;
+  };
+
+  int pipefds[2] = { -1, -1};
+  if (!preparePipe(pipefds, "communication")) {
+    return;
   }
 
+  int crossProtocolFDs[2] = { -1, -1};
+  if (!preparePipe(crossProtocolFDs, "cross-protocol")) {
+    return;
+  }
+
+  vinfolog("Adding TCP Client thread");
+
   {
     std::lock_guard<std::mutex> lock(d_mutex);
 
     if (d_numthreads >= d_tcpclientthreads.size()) {
       vinfolog("Adding a new TCP client thread would exceed the vector size (%d/%d), skipping. Consider increasing the maximum amount of TCP client threads with setMaxTCPClientThreads() in the configuration.", d_numthreads.load(), d_tcpclientthreads.size());
-      if (!d_useSinglePipe) {
-        close(pipefds[0]);
-        close(pipefds[1]);
-      }
+      close(pipefds[0]);
+      close(pipefds[1]);
       return;
     }
 
+    /* from now on this side of the pipe will be managed by that object,
+       no need to worry about it */
+    TCPWorkerThread worker(pipefds[1], crossProtocolFDs[1]);
     try {
-      std::thread t1(tcpClientThread, pipefds[0]);
+      std::thread t1(tcpClientThread, pipefds[0], crossProtocolFDs[0]);
       t1.detach();
     }
     catch (const std::runtime_error& e) {
       /* the thread creation failed, don't leak */
       errlog("Error creating a TCP thread: %s", e.what());
-      if (!d_useSinglePipe) {
-        close(pipefds[0]);
-        close(pipefds[1]);
-      }
+      close(pipefds[0]);
       return;
     }
 
-    d_tcpclientthreads.at(d_numthreads) = pipefds[1];
+    d_tcpclientthreads.at(d_numthreads) = std::move(worker);
     ++d_numthreads;
   }
 }
@@ -369,7 +349,7 @@ static IOState sendQueuedResponses(std::shared_ptr<IncomingTCPConnectionState>&
 
 static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state, const TCPResponse& currentResponse)
 {
-  if (state->d_isXFR || currentResponse.d_idstate.qtype == QType::AXFR || currentResponse.d_idstate.qtype == QType::IXFR) {
+  if (currentResponse.d_idstate.qtype == QType::AXFR || currentResponse.d_idstate.qtype == QType::IXFR) {
     return;
   }
 
@@ -399,6 +379,16 @@ static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& stat
   }
 }
 
+static void prependSizeToTCPQuery(PacketBuffer& buffer)
+{
+  uint16_t queryLen = buffer.size();
+  const uint8_t sizeBytes[] = { static_cast<uint8_t>(queryLen / 256), static_cast<uint8_t>(queryLen % 256) };
+  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
+     that could occur if we had to deal with the size during the processing,
+     especially alignment issues */
+  buffer.insert(buffer.begin(), sizeBytes, sizeBytes + 2);
+}
+
 bool IncomingTCPConnectionState::canAcceptNewQueries(const struct timeval& now)
 {
   if (d_hadErrors) {
@@ -406,11 +396,6 @@ bool IncomingTCPConnectionState::canAcceptNewQueries(const struct timeval& now)
     return false;
   }
 
-  if (d_isXFR) {
-    DEBUGLOG("not accepting new queries because used for XFR");
-    return false;
-  }
-
   if (d_currentQueriesCount >= d_ci.cs->d_maxInFlightQueriesPerConn) {
     DEBUGLOG("not accepting new queries because we already have "<<d_currentQueriesCount<<" out of "<<d_ci.cs->d_maxInFlightQueriesPerConn);
     return false;
@@ -434,9 +419,6 @@ void IncomingTCPConnectionState::resetForNewQuery()
   d_buffer.resize(sizeof(uint16_t));
   d_currentPos = 0;
   d_querySize = 0;
-  d_xfrMasterSerial = 0;
-  d_xfrSerialCount = 0;
-  d_xfrMasterSerialCount = 0;
   d_state = State::waitingForQuery;
 }
 
@@ -548,8 +530,10 @@ void IncomingTCPConnectionState::queueResponse(std::shared_ptr<IncomingTCPConnec
 }
 
 /* called from the backend code when a new response has been received */
-void IncomingTCPConnectionState::handleResponse(std::shared_ptr<IncomingTCPConnectionState> state, const struct timeval& now, TCPResponse&& response)
+void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPResponse&& response)
 {
+  std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
+
   if (response.d_connection && response.d_connection->isIdle()) {
     // if we have added a TCP Proxy Protocol payload to a connection, don't release it to the general pool yet, no one else will be able to use it anyway
     if (response.d_connection->canBeReused()) {
@@ -680,12 +664,12 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
   uint16_t qtype, qclass;
   unsigned int qnameWireLength = 0;
   DNSName qname(reinterpret_cast<const char*>(state->d_buffer.data()), state->d_buffer.size(), sizeof(dnsheader), false, &qtype, &qclass, &qnameWireLength);
-  DNSQuestion::Protocol protocol = DNSQuestion::Protocol::DoTCP;
+  dnsdist::Protocol protocol = dnsdist::Protocol::DoTCP;
   if (dnsCryptQuery) {
-    protocol = DNSQuestion::Protocol::DNSCryptTCP;
+    protocol = dnsdist::Protocol::DNSCryptTCP;
   }
   else if (state->d_handler.isTLS()) {
-    protocol = DNSQuestion::Protocol::DoT;
+    protocol = dnsdist::Protocol::DoT;
   }
 
   DNSQuestion dq(&qname, qtype, qclass, &state->d_proxiedDestination, &state->d_proxiedRemote, state->d_buffer, protocol, &queryRealTime);
@@ -697,8 +681,7 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
     dq.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(*state->d_proxyProtocolValues);
   }
 
-  state->d_isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
-  if (state->d_isXFR) {
+  if (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR) {
     dq.skipCache = true;
   }
 
@@ -731,15 +714,9 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
   setIDStateFromDNSQuestion(ids, dq, std::move(qname));
   ids.origID = ntohs(dh->id);
 
-  uint16_t queryLen = state->d_buffer.size();
-  const uint8_t sizeBytes[] = { static_cast<uint8_t>(queryLen / 256), static_cast<uint8_t>(queryLen % 256) };
-  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
-     that could occur if we had to deal with the size during the processing,
-     especially alignment issues */
-  state->d_buffer.insert(state->d_buffer.begin(), sizeBytes, sizeBytes + 2);
+  prependSizeToTCPQuery(state->d_buffer);
 
   auto downstreamConnection = state->getDownstreamConnection(ds, dq.proxyProtocolValues, now);
-  downstreamConnection->assignToClientConnection(state, state->d_isXFR);
 
   bool proxyProtocolPayloadAdded = false;
   std::string proxyProtocolPayload;
@@ -772,7 +749,8 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
 
   ++state->d_currentQueriesCount;
   vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), query.d_buffer.size(), ds->getName());
-  downstreamConnection->queueQuery(std::move(query), downstreamConnection);
+  std::shared_ptr<TCPQuerySender> incoming = state;
+  downstreamConnection->queueQuery(incoming, std::move(query));
 }
 
 void IncomingTCPConnectionState::handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
@@ -1034,8 +1012,10 @@ void IncomingTCPConnectionState::handleIO(std::shared_ptr<IncomingTCPConnectionS
   while ((iostate == IOState::NeedRead || iostate == IOState::NeedWrite) && !state->d_lastIOBlocked);
 }
 
-void IncomingTCPConnectionState::notifyIOError(std::shared_ptr<IncomingTCPConnectionState>& state, IDState&& query, const struct timeval& now)
+void IncomingTCPConnectionState::notifyIOError(IDState&& query, const struct timeval& now)
 {
+  std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
+
   --state->d_currentQueriesCount;
   state->d_hadErrors = true;
 
@@ -1062,8 +1042,9 @@ void IncomingTCPConnectionState::notifyIOError(std::shared_ptr<IncomingTCPConnec
   }
 }
 
-void IncomingTCPConnectionState::handleXFRResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response)
+void IncomingTCPConnectionState::handleXFRResponse(const struct timeval& now, TCPResponse&& response)
 {
+  std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
   queueResponse(state, now, std::move(response));
 }
 
@@ -1124,14 +1105,56 @@ static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param
 
     IncomingTCPConnectionState::handleIO(state, now);
   }
-  catch(...) {
+  catch (...) {
     delete citmp;
     citmp = nullptr;
     throw;
   }
 }
 
-static void tcpClientThread(int pipefd)
+static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& param)
+{
+  auto threadData = boost::any_cast<TCPClientThreadData*>(param);
+  CrossProtocolQuery* tmp{nullptr};
+
+  ssize_t got = read(pipefd, &tmp, sizeof(tmp));
+  if (got == 0) {
+    throw std::runtime_error("EOF while reading from the TCP cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
+  }
+  else if (got == -1) {
+    if (errno == EAGAIN || errno == EINTR) {
+      return;
+    }
+    throw std::runtime_error("Error while reading from the TCP cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror());
+  }
+  else if (got != sizeof(tmp)) {
+    throw std::runtime_error("Partial read while reading from the TCP cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
+  }
+
+  try {
+    struct timeval now;
+    gettimeofday(&now, nullptr);
+
+    auto query = std::move(tmp->query);
+    auto downstreamServer = std::move(tmp->downstream);
+    std::shared_ptr<TCPQuerySender> tqs = tmp->getTCPQuerySender();
+    delete tmp;
+    tmp = nullptr;
+
+    auto downstream = DownstreamConnectionsManager::getConnectionToDownstream(threadData->mplexer, downstreamServer, now);
+
+#warning FIXME: what if a proxy protocol payload was inserted?
+    prependSizeToTCPQuery(query.d_buffer);
+    downstream->queueQuery(tqs, std::move(query));
+  }
+  catch (...) {
+    delete tmp;
+    tmp = nullptr;
+    throw;
+  }
+}
+
+static void tcpClientThread(int pipefd, int crossProtocolPipeFD)
 {
   /* we get launched with a pipe on which we receive file descriptors from clients that we own
      from that point on */
@@ -1141,6 +1164,8 @@ static void tcpClientThread(int pipefd)
   TCPClientThreadData data;
 
   data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data);
+  data.mplexer->addReadFD(crossProtocolPipeFD, handleCrossProtocolQuery, &data);
+
   struct timeval now;
   gettimeofday(&now, nullptr);
   time_t lastTCPCleanup = now.tv_sec;
@@ -1238,7 +1263,6 @@ void tcpAcceptorThread(ClientState* cs)
 
   auto acl = g_ACL.getLocal();
   for(;;) {
-    bool queuedCounterIncremented = false;
     std::unique_ptr<ConnectionInfo> ci;
     tcpClientCountIncremented = false;
     try {
@@ -1294,23 +1318,7 @@ void tcpAcceptorThread(ClientState* cs)
       vinfolog("Got TCP connection from %s", remote.toStringWithPort());
 
       ci->remote = remote;
-      int pipe = g_tcpclientthreads->getThread();
-      if (pipe >= 0) {
-        queuedCounterIncremented = true;
-        auto tmp = ci.release();
-        try {
-          // throws on failure
-          writen2WithTimeout(pipe, &tmp, sizeof(tmp), timeval{0,0});
-        }
-        catch (...) {
-          delete tmp;
-          tmp = nullptr;
-          throw;
-        }
-      }
-      else {
-        g_tcpclientthreads->decrementQueuedCount();
-        queuedCounterIncremented = false;
+      if (!g_tcpclientthreads->passConnectionToThread(std::move(ci))) {
         if (tcpClientCountIncremented) {
           decrementTCPClientCount(remote);
         }
@@ -1321,9 +1329,6 @@ void tcpAcceptorThread(ClientState* cs)
       if (tcpClientCountIncremented) {
         decrementTCPClientCount(remote);
       }
-      if (queuedCounterIncremented) {
-        g_tcpclientthreads->decrementQueuedCount();
-      }
     }
     catch (...){}
   }
index cf96c599ad06a313bbe803d4df735c8b118babf1..f71ecb878a8645dab7e3753e5aa607eb77c04d75 100644 (file)
@@ -54,6 +54,7 @@
 #include "dnsdist-proxy-protocol.hh"
 #include "dnsdist-rings.hh"
 #include "dnsdist-secpoll.hh"
+#include "dnsdist-tcp.hh"
 #include "dnsdist-web.hh"
 #include "dnsdist-xpf.hh"
 
@@ -78,8 +79,8 @@
 
 /* the RuleAction plan
    Set of Rules, if one matches, it leads to an Action
-   Both rules and actions could conceivably be Lua based. 
-   On the C++ side, both could be inherited from a class Rule and a class Action, 
+   Both rules and actions could conceivably be Lua based.
+   On the C++ side, both could be inherited from a class Rule and a class Action,
    on the Lua side we can't do that. */
 
 using std::thread;
@@ -107,7 +108,7 @@ GlobalStateHolder<pools_t> g_pools;
 size_t g_udpVectorSize{1};
 
 /* UDP: the grand design. Per socket we listen on for incoming queries there is one thread.
-   Then we have a bunch of connected sockets for talking to downstream servers. 
+   Then we have a bunch of connected sockets for talking to downstream servers.
    We send directly to those sockets.
 
    For the return path, per downstream server we have a thread that listens to responses.
@@ -115,7 +116,7 @@ size_t g_udpVectorSize{1};
    Per socket there is an array of 2^16 states, when we send out a packet downstream, we note
    there the original requestor and the original id. The new ID is the offset in the array.
 
-   When an answer comes in on a socket, we look up the offset by the id, and lob it to the 
+   When an answer comes in on a socket, we look up the offset by the id, and lob it to the
    original requestor.
 
    IDs are assigned by atomic increments of the socket offset.
@@ -633,6 +634,16 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
 
         dh->id = ids->origID;
 
+        /* don't call processResponse on a truncated answer for DoH, we will retry over TCP */
+        if (du && dh->tc) {
+#ifdef HAVE_DNS_OVER_HTTPS
+          // DoH query
+          cerr<<"truncated answer for DoH"<<endl;
+          du->handleUDPResponse(std::move(response), std::move(*ids));
+#endif
+          continue;
+        }
+
         DNSResponse dr = makeDNSResponseFromIDState(*ids, response);
         if (dh->tc && g_truncateTC) {
           truncateTC(response, dr.getMaximumSize(), qnameWireLength);
@@ -647,27 +658,11 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
           if (du) {
 #ifdef HAVE_DNS_OVER_HTTPS
             // DoH query
-            du->response = std::move(response);
-            static_assert(sizeof(du) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail");
-            ssize_t sent = write(du->rsock, &du, sizeof(du));
-            if (sent != sizeof(du)) {
-              if (errno == EAGAIN || errno == EWOULDBLOCK) {
-                ++g_stats.dohResponsePipeFull;
-                vinfolog("Unable to pass a DoH response to the DoH worker thread because the pipe is full");
-              }
-              else {
-                vinfolog("Unable to pass a DoH response to the DoH worker thread because we couldn't write to the pipe: %s", stringerror());
-              }
-
-              /* at this point we have the only remaining pointer on this
-                 DOHUnit object since we did set ids->du to nullptr earlier,
-                 except if we got the response before the pointer could be
-                 released by the frontend */
-              du->release();
-            }
-#endif /* HAVE_DNS_OVER_HTTPS */
+            du->handleUDPResponse(std::move(response), IDState());
+#endif
             du = nullptr;
           }
+
           else {
             ComboAddress empty;
             empty.sin4.sin_family = 0;
@@ -889,7 +884,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru
       case DNSAction::Action::Refused:
         vinfolog("Query from %s refused because of dynamic block", dq.remote->toStringWithPort());
         updateBlockStats();
-      
+
         dq.getHeader()->rcode = RCode::Refused;
         dq.getHeader()->qr = true;
         return true;
@@ -954,7 +949,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru
       case DNSAction::Action::Truncate:
         if (!dq.overTCP()) {
           updateBlockStats();
-      
+
           vinfolog("Query from %s for %s truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
           dq.getHeader()->tc = true;
           dq.getHeader()->qr = true;
@@ -1210,7 +1205,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders&
       // we need ECS parsing (parseECS) to be true so we can be sure that the initial incoming query did not have an existing
       // ECS option, which would make it unsuitable for the zero-scope feature.
       if (dq.packetCache && !dq.skipCache && (!selectedBackend || !selectedBackend->disableZeroScope) && dq.packetCache->isECSParsingEnabled()) {
-        if (dq.packetCache->get(dq, dq.getHeader()->id, &dq.cacheKeyNoECS, dq.subnet, dq.dnssecOK, !dq.overTCP() || dq.getProtocol() == DNSQuestion::Protocol::DoH, allowExpired)) {
+        if (dq.packetCache->get(dq, dq.getHeader()->id, &dq.cacheKeyNoECS, dq.subnet, dq.dnssecOK, !dq.overTCP() || dq.getProtocol() == dnsdist::Protocol::DoH, allowExpired)) {
 
           if (!prepareOutgoingResponse(holders, cs, dq, true)) {
             return ProcessQueryResult::Drop;
@@ -1232,7 +1227,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders&
     }
 
     if (dq.packetCache && !dq.skipCache) {
-      if (dq.packetCache->get(dq, dq.getHeader()->id, &dq.cacheKey, dq.subnet, dq.dnssecOK, !dq.overTCP() || dq.getProtocol() == DNSQuestion::Protocol::DoH, allowExpired)) {
+      if (dq.packetCache->get(dq, dq.getHeader()->id, &dq.cacheKey, dq.subnet, dq.dnssecOK, !dq.overTCP() || dq.getProtocol() == dnsdist::Protocol::DoH, allowExpired)) {
 
         restoreFlags(dq.getHeader(), dq.origFlags);
 
@@ -1334,7 +1329,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     uint16_t qtype, qclass;
     unsigned int qnameWireLength = 0;
     DNSName qname(reinterpret_cast<const char*>(query.data()), query.size(), sizeof(dnsheader), false, &qtype, &qclass, &qnameWireLength);
-    DNSQuestion dq(&qname, qtype, qclass, proxiedDestination.sin4.sin_family != 0 ? &proxiedDestination : &cs.local, &proxiedRemote, query, dnsCryptQuery ? DNSQuestion::Protocol::DNSCryptUDP : DNSQuestion::Protocol::DoUDP, &queryRealTime);
+    DNSQuestion dq(&qname, qtype, qclass, proxiedDestination.sin4.sin_family != 0 ? &proxiedDestination : &cs.local, &proxiedRemote, query, dnsCryptQuery ? dnsdist::Protocol::DNSCryptUDP : dnsdist::Protocol::DoUDP, &queryRealTime);
     dq.dnsCryptQuery = std::move(dnsCryptQuery);
     if (!proxyProtocolValues.empty()) {
       dq.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(std::move(proxyProtocolValues));
@@ -1716,7 +1711,7 @@ static void healthChecksThread()
       dss->dropRate.store(1.0*(dss->reuseds.load() - dss->prev.reuseds.load())/delta);
       dss->prev.queries.store(dss->queries.load());
       dss->prev.reuseds.store(dss->reuseds.load());
-      
+
       for (IDState& ids  : dss->idStates) { // timeouts
         int64_t usageIndicator = ids.usageIndicator;
         if(IDState::isInUse(usageIndicator) && ids.age++ > g_udpTimeout) {
@@ -1749,7 +1744,7 @@ static void healthChecksThread()
           fake.id = ids.origID;
 
           g_rings.insertResponse(ts, ids.origRemote, ids.qname, ids.qtype, std::numeric_limits<unsigned int>::max(), 0, fake, dss->remote);
-        }          
+        }
       }
     }
 
@@ -1969,7 +1964,7 @@ static void setUpLocalBind(std::unique_ptr<ClientState>& cs)
   cs->ready = true;
 }
 
-struct 
+struct
 {
   vector<string> locals;
   vector<string> remotes;
@@ -2055,7 +2050,7 @@ int main(int argc, char** argv)
       srandom(tv.tv_sec ^ tv.tv_usec ^ getpid());
       g_hashperturb=random();
     }
-  
+
 #endif
     ComboAddress clientAddress = ComboAddress();
     g_cmdLine.config=SYSCONFDIR "/dnsdist.conf";
@@ -2403,7 +2398,7 @@ int main(int argc, char** argv)
       g_maxTCPClientThreads = 1;
     }
 
-    g_tcpclientthreads = std::unique_ptr<TCPClientCollection>(new TCPClientCollection(*g_maxTCPClientThreads, g_useTCPSinglePipe));
+    g_tcpclientthreads = std::make_unique<TCPClientCollection>(*g_maxTCPClientThreads);
 
     for (auto& t : todo) {
       t();
@@ -2481,7 +2476,7 @@ int main(int argc, char** argv)
 
     thread stattid(maintThread);
     stattid.detach();
-  
+
     thread healththread(healthChecksThread);
 
     thread dynBlockMaintThread(dynBlockMaintenanceThread);
index fadd2345e4baabd3e423918bdf7fe997719148b9..272c9498e92922adbd4353608c594dec7e121397 100644 (file)
 #include "dnsdist-cache.hh"
 #include "dnsdist-dynbpf.hh"
 #include "dnsdist-lbpolicies.hh"
+#include "dnsdist-protocols.hh"
 #include "dnsname.hh"
 #include "doh.hh"
 #include "ednsoptions.hh"
-#include "gettime.hh"
 #include "iputils.hh"
 #include "misc.hh"
 #include "mplexer.hh"
@@ -52,7 +52,6 @@
 #include "proxy-protocol.hh"
 #include "stat_t.hh"
 
-void carbonDumpThread();
 uint64_t uptimeOfProcess(const std::string& str);
 
 extern uint16_t g_ECSSourcePrefixV4;
@@ -63,14 +62,7 @@ using QTag = std::unordered_map<string, string>;
 
 struct DNSQuestion
 {
-  enum class Protocol : uint8_t { DoUDP, DoTCP, DNSCryptUDP, DNSCryptTCP, DoT, DoH };
-  static const std::string& ProtocolToString(Protocol proto)
-  {
-    static const std::vector<std::string> values = { "Do53 UDP", "Do53 TCP", "DNSCrypt UDP", "DNSCrypt TCP", "DNS over TLS", "DNS over HTTPS" };
-    return values.at(static_cast<int>(proto));
-  }
-
-  DNSQuestion(const DNSName* name, uint16_t type, uint16_t class_, const ComboAddress* lc, const ComboAddress* rem, PacketBuffer& data_, Protocol proto, const struct timespec* queryTime_):
+  DNSQuestion(const DNSName* name, uint16_t type, uint16_t class_, const ComboAddress* lc, const ComboAddress* rem, PacketBuffer& data_, dnsdist::Protocol proto, const struct timespec* queryTime_):
     data(data_), qname(name), local(lc), remote(rem), queryTime(queryTime_), tempFailureTTL(boost::none), qtype(type), qclass(class_), ecsPrefixLength(rem->sin4.sin_family == AF_INET ? g_ECSSourcePrefixV4 : g_ECSSourcePrefixV6), protocol(proto), ecsOverride(g_ECSOverride) {
     const uint16_t* flags = getFlagsFromDNSHeader(getHeader());
     origFlags = *flags;
@@ -119,14 +111,14 @@ struct DNSQuestion
     return 4096;
   }
 
-  Protocol getProtocol() const
+  dnsdist::Protocol getProtocol() const
   {
     return protocol;
   }
 
   bool overTCP() const
   {
-    return !(protocol == Protocol::DoUDP || protocol == Protocol::DNSCryptUDP);
+    return !(protocol == dnsdist::Protocol::DoUDP || protocol == dnsdist::Protocol::DNSCryptUDP);
   }
 
 protected:
@@ -162,7 +154,7 @@ public:
   uint16_t ecsPrefixLength;
   uint16_t origFlags;
   uint16_t cacheFlags{0}; /* DNS flags as sent to the backend */
-  const Protocol protocol;
+  const dnsdist::Protocol protocol;
   uint8_t ednsRCode{0};
   bool skipCache{false};
   bool ecsOverride;
@@ -177,7 +169,7 @@ public:
 
 struct DNSResponse : DNSQuestion
 {
-  DNSResponse(const DNSName* name, uint16_t type, uint16_t class_, const ComboAddress* lc, const ComboAddress* rem, PacketBuffer& data_, DNSQuestion::Protocol proto, const struct timespec* queryTime_):
+  DNSResponse(const DNSName* name, uint16_t type, uint16_t class_, const ComboAddress* lc, const ComboAddress* rem, PacketBuffer& data_, dnsdist::Protocol proto, const struct timespec* queryTime_):
     DNSQuestion(name, type, class_, lc, rem, data_, proto, queryTime_) { }
   DNSResponse(const DNSResponse&) = delete;
   DNSResponse& operator=(const DNSResponse&) = delete;
@@ -420,44 +412,7 @@ struct DNSDistStats
 extern struct DNSDistStats g_stats;
 void doLatencyStats(double udiff);
 
-
-struct StopWatch
-{
-  StopWatch(bool realTime=false): d_needRealTime(realTime)
-  {
-  }
-  struct timespec d_start{0,0};
-  bool d_needRealTime{false};
-
-  void start() {
-    if(gettime(&d_start, d_needRealTime) < 0)
-      unixDie("Getting timestamp");
-
-  }
-
-  void set(const struct timespec& from) {
-    d_start = from;
-  }
-
-  double udiff() const {
-    struct timespec now;
-    if(gettime(&now, d_needRealTime) < 0)
-      unixDie("Getting timestamp");
-
-    return 1000000.0*(now.tv_sec - d_start.tv_sec) + (now.tv_nsec - d_start.tv_nsec)/1000.0;
-  }
-
-  double udiffAndSet() {
-    struct timespec now;
-    if(gettime(&now, d_needRealTime) < 0)
-      unixDie("Getting timestamp");
-
-    auto ret= 1000000.0*(now.tv_sec - d_start.tv_sec) + (now.tv_nsec - d_start.tv_nsec)/1000.0;
-    d_start = now;
-    return ret;
-  }
-
-};
+#include "dnsdist-idstate.hh"
 
 class BasicQPSLimiter
 {
@@ -568,189 +523,6 @@ private:
   bool d_passthrough{true};
 };
 
-struct ClientState;
-
-/* g++ defines __SANITIZE_THREAD__
-   clang++ supports the nice __has_feature(thread_sanitizer),
-   let's merge them */
-#if defined(__has_feature)
-#if __has_feature(thread_sanitizer)
-#define __SANITIZE_THREAD__ 1
-#endif
-#endif
-
-struct IDState
-{
-  IDState(): sentTime(true), tempFailureTTL(boost::none) { origDest.sin4.sin_family = 0;}
-  IDState(const IDState& orig) = delete;
-  IDState(IDState&& rhs): subnet(rhs.subnet), origRemote(rhs.origRemote), origDest(rhs.origDest), hopRemote(rhs.hopRemote), hopLocal(rhs.hopLocal), qname(std::move(rhs.qname)), sentTime(rhs.sentTime), dnsCryptQuery(std::move(rhs.dnsCryptQuery)), packetCache(std::move(rhs.packetCache)), qTag(std::move(rhs.qTag)), tempFailureTTL(rhs.tempFailureTTL), cs(rhs.cs), du(std::move(rhs.du)), cacheKey(rhs.cacheKey), cacheKeyNoECS(rhs.cacheKeyNoECS), origFD(rhs.origFD), delayMsec(rhs.delayMsec), qtype(rhs.qtype), qclass(rhs.qclass), origID(rhs.origID), origFlags(rhs.origFlags), cacheFlags(rhs.cacheFlags), protocol(rhs.protocol), ednsAdded(rhs.ednsAdded), ecsAdded(rhs.ecsAdded), skipCache(rhs.skipCache), destHarvested(rhs.destHarvested), dnssecOK(rhs.dnssecOK), useZeroScope(rhs.useZeroScope)
-  {
-    if (rhs.isInUse()) {
-      throw std::runtime_error("Trying to move an in-use IDState");
-    }
-
-    uniqueId = std::move(rhs.uniqueId);
-#ifdef __SANITIZE_THREAD__
-    age.store(rhs.age.load());
-#else
-    age = rhs.age;
-#endif
-  }
-
-  IDState& operator=(IDState&& rhs)
-  {
-    if (isInUse()) {
-      throw std::runtime_error("Trying to overwrite an in-use IDState");
-    }
-
-    if (rhs.isInUse()) {
-      throw std::runtime_error("Trying to move an in-use IDState");
-    }
-
-    subnet = std::move(rhs.subnet);
-    origRemote = rhs.origRemote;
-    origDest = rhs.origDest;
-    hopRemote = rhs.hopRemote;
-    hopLocal = rhs.hopLocal;
-    qname = std::move(rhs.qname);
-    sentTime = rhs.sentTime;
-    dnsCryptQuery = std::move(rhs.dnsCryptQuery);
-    packetCache = std::move(rhs.packetCache);
-    qTag = std::move(rhs.qTag);
-    tempFailureTTL = std::move(rhs.tempFailureTTL);
-    cs = rhs.cs;
-    du = std::move(rhs.du);
-    cacheKey = rhs.cacheKey;
-    cacheKeyNoECS = rhs.cacheKeyNoECS;
-    origFD = rhs.origFD;
-    delayMsec = rhs.delayMsec;
-#ifdef __SANITIZE_THREAD__
-    age.store(rhs.age.load());
-#else
-    age = rhs.age;
-#endif
-    qtype = rhs.qtype;
-    qclass = rhs.qclass;
-    origID = rhs.origID;
-    origFlags = rhs.origFlags;
-    cacheFlags = rhs.cacheFlags;
-    protocol = rhs.protocol;
-    uniqueId = std::move(rhs.uniqueId);
-    ednsAdded = rhs.ednsAdded;
-    ecsAdded = rhs.ecsAdded;
-    skipCache = rhs.skipCache;
-    destHarvested = rhs.destHarvested;
-    dnssecOK = rhs.dnssecOK;
-    useZeroScope = rhs.useZeroScope;
-
-    return *this;
-  }
-
-  static const int64_t unusedIndicator = -1;
-
-  static bool isInUse(int64_t usageIndicator)
-  {
-    return usageIndicator != unusedIndicator;
-  }
-
-  bool isInUse() const
-  {
-    return usageIndicator != unusedIndicator;
-  }
-
-  /* return true if the value has been successfully replaced meaning that
-     no-one updated the usage indicator in the meantime */
-  bool tryMarkUnused(int64_t expectedUsageIndicator)
-  {
-    return usageIndicator.compare_exchange_strong(expectedUsageIndicator, unusedIndicator);
-  }
-
-  /* mark as used no matter what, return true if the state was in use before */
-  bool markAsUsed()
-  {
-    auto currentGeneration = generation++;
-    return markAsUsed(currentGeneration);
-  }
-
-  /* mark as used no matter what, return true if the state was in use before */
-  bool markAsUsed(int64_t currentGeneration)
-  {
-    int64_t oldUsage = usageIndicator.exchange(currentGeneration);
-    return oldUsage != unusedIndicator;
-  }
-
-  /* We use this value to detect whether this state is in use.
-     For performance reasons we don't want to use a lock here, but that means
-     we need to be very careful when modifying this value. Modifications happen
-     from:
-     - one of the UDP or DoH 'client' threads receiving a query, selecting a backend
-       then picking one of the states associated to this backend (via the idOffset).
-       Most of the time this state should not be in use and usageIndicator is -1, but we
-       might not yet have received a response for the query previously associated to this
-       state, meaning that we will 'reuse' this state and erase the existing state.
-       If we ever receive a response for this state, it will be discarded. This is
-       mostly fine for UDP except that we still need to be careful in order to miss
-       the 'outstanding' counters, which should only be increased when we are picking
-       an empty state, and not when reusing ;
-       For DoH, though, we have dynamically allocated a DOHUnit object that needs to
-       be freed, as well as internal objects internals to libh2o.
-     - one of the UDP receiver threads receiving a response from a backend, picking
-       the corresponding state and sending the response to the client ;
-     - the 'healthcheck' thread scanning the states to actively discover timeouts,
-       mostly to keep some counters like the 'outstanding' one sane.
-     We previously based that logic on the origFD (FD on which the query was received,
-     and therefore from where the response should be sent) but this suffered from an
-     ABA problem since it was quite likely that a UDP 'client thread' would reset it to the
-     same value since we only have so much incoming sockets:
-     - 1/ 'client' thread gets a query and set origFD to its FD, say 5 ;
-     - 2/ 'receiver' thread gets a response, read the value of origFD to 5, check that the qname,
-       qtype and qclass match
-     - 3/ during that time the 'client' thread reuses the state, setting again origFD to 5 ;
-     - 4/ the 'receiver' thread uses compare_exchange_strong() to only replace the value if it's still
-       5, except it's not the same 5 anymore and it overrides a fresh state.
-     We now use a 32-bit unsigned counter instead, which is incremented every time the state is set,
-     wrapping around if necessary, and we set an atomic signed 64-bit value, so that we still have -1
-     when the state is unused and the value of our counter otherwise.
-  */
-  boost::optional<Netmask> subnet{boost::none};               // 40
-  ComboAddress origRemote;                                    // 28
-  ComboAddress origDest;                                      // 28
-  ComboAddress hopRemote;
-  ComboAddress hopLocal;
-  DNSName qname;                                              // 24
-  StopWatch sentTime;                                         // 16
-  std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};      // 16
-  std::shared_ptr<DNSDistPacketCache> packetCache{nullptr};   // 16
-  std::shared_ptr<QTag> qTag{nullptr};                        // 16
-  boost::optional<uint32_t> tempFailureTTL;                   // 8
-  const ClientState* cs{nullptr};                             // 8
-  DOHUnit* du{nullptr};                                       // 8
-  std::atomic<int64_t> usageIndicator{unusedIndicator};  // set to unusedIndicator to indicate this state is empty   // 8
-  std::atomic<uint32_t> generation{0}; // increased every time a state is used, to be able to detect an ABA issue    // 4
-  uint32_t cacheKey{0};                                       // 4
-  uint32_t cacheKeyNoECS{0};                                  // 4
-  int origFD{-1};                                             // 4
-  int delayMsec{0};
-#ifdef __SANITIZE_THREAD__
-  std::atomic<uint16_t> age{0};
-#else
-  uint16_t age{0};                                            // 2
-#endif
-  uint16_t qtype{0};                                          // 2
-  uint16_t qclass{0};                                         // 2
-  uint16_t origID{0};                                         // 2
-  uint16_t origFlags{0};                                      // 2
-  uint16_t cacheFlags{0}; // DNS flags as sent to the backend // 2
-  DNSQuestion::Protocol protocol;                             // 1
-  boost::optional<boost::uuids::uuid> uniqueId{boost::none};  // 17 (placed here to reduce the space lost to padding)
-  bool ednsAdded{false};
-  bool ecsAdded{false};
-  bool skipCache{false};
-  bool destHarvested{false}; // if true, origDest holds the original dest addr, otherwise the listening addr
-  bool dnssecOK{false};
-  bool useZeroScope{false};
-};
-
 typedef std::unordered_map<string, unsigned int> QueryCountRecords;
 typedef std::function<std::tuple<bool, string>(const DNSQuestion* dq)> QueryCountFilter;
 struct QueryCount {
@@ -781,15 +553,15 @@ struct ClientState
   std::string interface;
   stat_t queries{0};
   mutable stat_t responses{0};
-  stat_t tcpDiedReadingQuery{0};
-  stat_t tcpDiedSendingResponse{0};
-  stat_t tcpGaveUp{0};
-  stat_t tcpClientTimeouts{0};
-  stat_t tcpDownstreamTimeouts{0};
+  mutable stat_t tcpDiedReadingQuery{0};
+  mutable stat_t tcpDiedSendingResponse{0};
+  mutable stat_t tcpGaveUp{0};
+  mutable stat_t tcpClientTimeouts{0};
+  mutable stat_t tcpDownstreamTimeouts{0};
   /* current number of connections to this frontend */
-  stat_t tcpCurrentConnections{0};
+  mutable stat_t tcpCurrentConnections{0};
   /* maximum number of concurrent connections to this frontend reached */
-  stat_t tcpMaxConcurrentConnections{0};
+  mutable stat_t tcpMaxConcurrentConnections{0};
   stat_t tlsNewSessions{0}; // A new TLS session has been negotiated, no resumption
   stat_t tlsResumptions{0}; // A TLS session has been resumed, either via session id or via a TLS ticket
   stat_t tlsUnknownTicketKey{0}; // A TLS ticket has been presented but we don't have the associated key (might have expired)
@@ -875,49 +647,6 @@ struct ClientState
   }
 };
 
-class TCPClientCollection {
-  std::vector<int> d_tcpclientthreads;
-  stat_t d_numthreads{0};
-  stat_t d_pos{0};
-  stat_t d_queued{0};
-  const uint64_t d_maxthreads{0};
-  std::mutex d_mutex;
-  int d_singlePipe[2];
-  const bool d_useSinglePipe;
-public:
-
-  TCPClientCollection(size_t maxThreads, bool useSinglePipe=false);
-  int getThread()
-  {
-    if (d_numthreads == 0) {
-      throw std::runtime_error("No TCP worker thread yet");
-    }
-
-    uint64_t pos = d_pos++;
-    ++d_queued;
-    return d_tcpclientthreads.at(pos % d_numthreads);
-  }
-  bool hasReachedMaxThreads() const
-  {
-    return d_numthreads >= d_maxthreads;
-  }
-  uint64_t getThreadsCount() const
-  {
-    return d_numthreads;
-  }
-  uint64_t getQueuedCount() const
-  {
-    return d_queued;
-  }
-  void decrementQueuedCount()
-  {
-    --d_queued;
-  }
-  void addTCPClientThread();
-};
-
-extern std::unique_ptr<TCPClientCollection> g_tcpclientthreads;
-
 struct DownstreamState
 {
    typedef std::function<std::tuple<DNSName, uint16_t, uint16_t>(const DNSName&, uint16_t, uint16_t, dnsheader*)> checkfunc_t;
@@ -1269,3 +998,5 @@ void setIDStateFromDNSQuestion(IDState& ids, DNSQuestion& dq, DNSName&& qname);
 
 int pickBackendSocketForSending(std::shared_ptr<DownstreamState>& state);
 ssize_t udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState>& ss, const int sd, const PacketBuffer& request, bool healthCheck = false);
+
+void carbonDumpThread();
index 42b16a3de596b9020cb9f9522405e8ae639d4ebb..83a579f9832b7dbcb9f839f77cf544ad0a109912 100644 (file)
@@ -140,7 +140,7 @@ dnsdist_SOURCES = \
        dnsdist-dynbpf.cc dnsdist-dynbpf.hh \
        dnsdist-ecs.cc dnsdist-ecs.hh \
        dnsdist-healthchecks.cc dnsdist-healthchecks.hh \
-       dnsdist-idstate.cc \
+       dnsdist-idstate.cc dnsdist-idstate.hh \
        dnsdist-kvs.hh dnsdist-kvs.cc \
        dnsdist-lbpolicies.cc dnsdist-lbpolicies.hh \
        dnsdist-lua-actions.cc \
@@ -160,6 +160,7 @@ dnsdist_SOURCES = \
        dnsdist-lua.cc dnsdist-lua.hh \
        dnsdist-prometheus.hh \
        dnsdist-protobuf.cc dnsdist-protobuf.hh \
+       dnsdist-protocols.cc dnsdist-protocols.hh \
        dnsdist-proxy-protocol.cc dnsdist-proxy-protocol.hh \
        dnsdist-rings.cc dnsdist-rings.hh \
        dnsdist-rules.cc dnsdist-rules.hh \
@@ -168,7 +169,7 @@ dnsdist_SOURCES = \
        dnsdist-systemd.cc dnsdist-systemd.hh \
        dnsdist-tcp-downstream.cc dnsdist-tcp-downstream.hh \
        dnsdist-tcp-upstream.hh \
-       dnsdist-tcp.cc \
+       dnsdist-tcp.cc dnsdist-tcp.hh \
        dnsdist-web.cc dnsdist-web.hh \
        dnsdist-xpf.cc dnsdist-xpf.hh \
        dnsdist.cc dnsdist.hh \
@@ -229,7 +230,7 @@ testrunner_SOURCES = \
        dnsdist-dynblocks.cc dnsdist-dynblocks.hh \
        dnsdist-dynbpf.cc dnsdist-dynbpf.hh \
        dnsdist-ecs.cc dnsdist-ecs.hh \
-       dnsdist-idstate.cc \
+       dnsdist-idstate.cc dnsdist-idstate.hh \
        dnsdist-kvs.cc dnsdist-kvs.hh \
        dnsdist-lbpolicies.cc dnsdist-lbpolicies.hh \
        dnsdist-lua-bindings-dnsquestion.cc \
@@ -238,11 +239,12 @@ testrunner_SOURCES = \
        dnsdist-lua-ffi-interface.h dnsdist-lua-ffi-interface.inc \
        dnsdist-lua-ffi.cc dnsdist-lua-ffi.hh \
        dnsdist-lua-vars.cc \
+       dnsdist-protocols.cc dnsdist-protocols.hh \
        dnsdist-proxy-protocol.cc dnsdist-proxy-protocol.hh \
        dnsdist-rings.cc dnsdist-rings.hh \
        dnsdist-rules.cc dnsdist-rules.hh \
        dnsdist-tcp-downstream.cc \
-       dnsdist-tcp.cc \
+       dnsdist-tcp.cc dnsdist-tcp.hh \
        dnsdist-xpf.cc dnsdist-xpf.hh \
        dnsdist.hh \
        dnslabeltext.cc \
diff --git a/pdns/dnsdistdist/dnsdist-idstate.hh b/pdns/dnsdistdist/dnsdist-idstate.hh
new file mode 120000 (symlink)
index 0000000..44f6de4
--- /dev/null
@@ -0,0 +1 @@
+../dnsdist-idstate.hh
\ No newline at end of file
diff --git a/pdns/dnsdistdist/dnsdist-protocols.cc b/pdns/dnsdistdist/dnsdist-protocols.cc
new file mode 100644 (file)
index 0000000..233bf4b
--- /dev/null
@@ -0,0 +1,31 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#include "dnsdist-protocols.hh"
+
+namespace dnsdist {
+  const std::string& ProtocolToString(Protocol proto)
+  {
+    static const std::vector<std::string> values = { "Do53 UDP", "Do53 TCP", "DNSCrypt UDP", "DNSCrypt TCP", "DNS over TLS", "DNS over HTTPS" };
+    return values.at(static_cast<int>(proto));
+  }
+}
+
diff --git a/pdns/dnsdistdist/dnsdist-protocols.hh b/pdns/dnsdistdist/dnsdist-protocols.hh
new file mode 120000 (symlink)
index 0000000..cb9d2fd
--- /dev/null
@@ -0,0 +1 @@
+../dnsdist-protocols.hh
\ No newline at end of file
index 7a8b28cd6e74e1208cf8c5acf3b1dab04517cf37..9623866699a5d9052998f1bb366b2f561e4c0e09 100644 (file)
@@ -4,37 +4,14 @@
 
 #include "dnsparser.hh"
 
-const uint16_t TCPConnectionToBackend::s_xfrID = 0;
-
-void TCPConnectionToBackend::assignToClientConnection(std::shared_ptr<IncomingTCPConnectionState>& clientConn, bool isXFR)
-{
-  if (d_usedForXFR == true) {
-    throw std::runtime_error("Trying to send a query over a backend connection used for XFR");
-  }
-
-  if (isXFR) {
-    d_usedForXFR = true;
-  }
-
-  if (!d_clientConn) {
-    d_clientConn = clientConn;
-    d_ioState = make_unique<IOStateHandler>(clientConn->getIOMPlexer(), d_handler->getDescriptor());
-  }
-  else if (d_clientConn != clientConn) {
-    throw std::runtime_error("Assigning a query from a different client to an existing backend connection with pending queries");
-  }
-}
-
 void TCPConnectionToBackend::release()
 {
-  if (!d_usedForXFR) {
-    d_ds->outstanding -= d_pendingResponses.size();
-  }
+  d_ds->outstanding -= d_pendingResponses.size();
 
   d_pendingResponses.clear();
   d_pendingQueries.clear();
 
-  d_clientConn.reset();
+  d_sender.reset();
   if (d_ioState) {
     d_ioState.reset();
   }
@@ -72,9 +49,7 @@ IOState TCPConnectionToBackend::sendQuery(std::shared_ptr<TCPConnectionToBackend
   conn->d_pendingResponses[conn->d_currentQuery.d_idstate.origID] = std::move(conn->d_currentQuery);
   conn->d_currentQuery.d_buffer.clear();
 
-  if (!conn->d_usedForXFR) {
-    ++conn->d_ds->outstanding;
-  }
+  ++conn->d_ds->outstanding;
 
   return state;
 }
@@ -185,20 +160,34 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr<TCPConnectionToBackend>& c
 
       DEBUGLOG("connection died, number of failures is "<<conn->d_downstreamFailures<<", retries is "<<conn->d_ds->retries);
 
-      if ((!conn->d_usedForXFR || conn->d_queries == 0) && conn->d_downstreamFailures < conn->d_ds->retries) {
+      if (conn->d_downstreamFailures < conn->d_ds->retries) {
 
         conn->d_ioState.reset();
         ioGuard.release();
 
         try {
           if (conn->reconnect()) {
-            conn->d_ioState = make_unique<IOStateHandler>(conn->d_clientConn->getIOMPlexer(), conn->d_handler->getDescriptor());
+            conn->d_ioState = make_unique<IOStateHandler>(conn->d_mplexer, conn->d_handler->getDescriptor());
 
             /* we need to resend the queries that were in flight, if any */
             for (auto& pending : conn->d_pendingResponses) {
-              conn->d_pendingQueries.push_back(std::move(pending.second));
-              if (!conn->d_usedForXFR) {
-                --conn->d_ds->outstanding;
+              --conn->d_ds->outstanding;
+
+              if (pending.second.isXFR() && pending.second.d_xfrStarted) {
+                /* this one can't be restarted, sorry */
+                DEBUGLOG("A XFR for which a response has already been sent cannot be restarted");
+                try {
+                  conn->d_sender->notifyIOError(std::move(pending.second.d_idstate), now);
+                }
+                catch (const std::exception& e) {
+                  vinfolog("Got an exception while notifying: %s", e.what());
+                }
+                catch (...) {
+                  vinfolog("Got exception while notifying");
+                }
+              }
+              else {
+                conn->d_pendingQueries.push_back(std::move(pending.second));
               }
             }
             conn->d_pendingResponses.clear();
@@ -278,10 +267,14 @@ void TCPConnectionToBackend::handleIOCallback(int fd, FDMultiplexer::funcparam_t
   handleIO(conn, now);
 }
 
-void TCPConnectionToBackend::queueQuery(TCPQuery&& query, std::shared_ptr<TCPConnectionToBackend>& sharedSelf)
+void TCPConnectionToBackend::queueQuery(std::shared_ptr<TCPQuerySender>& sender, TCPQuery&& query)
 {
-  if (d_ioState == nullptr) {
-    throw std::runtime_error("Trying to queue a query to a TCP connection that has no incoming client connection assigned");
+  if (!d_sender) {
+    d_sender = sender;
+    d_ioState = make_unique<IOStateHandler>(d_mplexer, d_handler->getDescriptor());
+  }
+  else if (d_sender != sender) {
+    throw std::runtime_error("Assigning a query from a different client to an existing backend connection with pending queries");
   }
 
   // if we are not already sending a query or in the middle of reading a response (so idle or doingHandshake),
@@ -299,7 +292,8 @@ void TCPConnectionToBackend::queueQuery(TCPQuery&& query, std::shared_ptr<TCPCon
     struct timeval now;
     gettimeofday(&now, 0);
 
-    handleIO(sharedSelf, now);
+    auto shared = shared_from_this();
+    handleIO(shared, now);
   }
   else {
     DEBUGLOG("Adding new query to the queue because we are in state "<<(int)d_state);
@@ -404,31 +398,31 @@ void TCPConnectionToBackend::notifyAllQueriesFailed(const struct timeval& now, F
 {
   d_connectionDied = true;
 
-  auto& clientConn = d_clientConn;
-  if (!clientConn->active()) {
+  auto& sender = d_sender;
+  if (!sender->active()) {
     // a client timeout occurred, or something like that */
-    d_clientConn.reset();
+    d_sender.reset();
     return;
   }
 
   if (reason == FailureReason::timeout) {
-    ++clientConn->d_ci.cs->tcpDownstreamTimeouts;
+    ++sender->getClientState().tcpDownstreamTimeouts;
   }
   else if (reason == FailureReason::gaveUp) {
-    ++clientConn->d_ci.cs->tcpGaveUp;
+    ++sender->getClientState().tcpGaveUp;
   }
 
   try {
     if (d_state == State::sendingQueryToBackend) {
-      clientConn->notifyIOError(clientConn, std::move(d_currentQuery.d_idstate), now);
+      sender->notifyIOError(std::move(d_currentQuery.d_idstate), now);
     }
 
     for (auto& query : d_pendingQueries) {
-      clientConn->notifyIOError(clientConn, std::move(query.d_idstate), now);
+      sender->notifyIOError(std::move(query.d_idstate), now);
     }
 
     for (auto& response : d_pendingResponses) {
-      clientConn->notifyIOError(clientConn, std::move(response.second.d_idstate), now);
+      sender->notifyIOError(std::move(response.second.d_idstate), now);
     }
   }
   catch (const std::exception& e) {
@@ -467,8 +461,8 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptr<TCPConnectionToBa
 {
   d_downstreamFailures = 0;
 
-  auto& clientConn = d_clientConn;
-  if (!clientConn || !clientConn->active()) {
+  auto& sender = d_sender;
+  if (!sender || !sender->active()) {
     // a client timeout occurred, or something like that */
     d_connectionDied = true;
 
@@ -494,11 +488,7 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptr<TCPConnectionToBa
     return IOState::Done;
   }
 
-  if (!conn->d_usedForXFR) {
-    --conn->d_ds->outstanding;
-  }
-
-  if (d_usedForXFR) {
+  if (it->second.isXFR()) {
     DEBUGLOG("XFR!");
     bool done = false;
     TCPResponse response;
@@ -509,22 +499,22 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptr<TCPConnectionToBa
     response.d_idstate.qname = it->second.d_idstate.qname;
     DEBUGLOG("passing XFRresponse to client connection for "<<response.d_idstate.qname);
 
-    done = isXFRFinished(response, clientConn);
+    it->second.d_xfrStarted = true;
+    done = isXFRFinished(response, it->second);
 
     if (done) {
       d_pendingResponses.erase(it);
+      --conn->d_ds->outstanding;
       /* marking as idle for now, so we can accept new queries if our queues are empty */
       if (d_pendingQueries.empty() && d_pendingResponses.empty()) {
         d_state = State::idle;
       }
-      clientConn->d_isXFR = false;
-      conn->d_usedForXFR = false;
     }
 
-    clientConn->handleXFRResponse(clientConn, now, std::move(response));
+    sender->handleXFRResponse(now, std::move(response));
     if (done) {
       d_state = State::idle;
-      d_clientConn.reset();
+      d_sender.reset();
       return IOState::Done;
     }
 
@@ -534,6 +524,9 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptr<TCPConnectionToBa
     // get ready to read the next packet, if any
     return IOState::NeedRead;
   }
+  else {
+    --conn->d_ds->outstanding;
+  }
 
   auto ids = std::move(it->second.d_idstate);
   d_pendingResponses.erase(it);
@@ -543,7 +536,7 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptr<TCPConnectionToBa
   }
 
   DEBUGLOG("passing response to client connection for "<<ids.qname);
-  clientConn->handleResponse(clientConn, now, TCPResponse(std::move(d_responseBuffer), std::move(ids), conn));
+  sender->handleResponse(now, TCPResponse(std::move(d_responseBuffer), std::move(ids), conn));
 
   if (!d_pendingQueries.empty()) {
     DEBUGLOG("still have some queries to send");
@@ -563,7 +556,7 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptr<TCPConnectionToBa
   else {
     DEBUGLOG("nothing to do, waiting for a new query");
     d_state = State::idle;
-    d_clientConn.reset();
+    d_sender.reset();
     return IOState::Done;
   }
 }
@@ -605,7 +598,7 @@ bool TCPConnectionToBackend::matchesTLVs(const std::unique_ptr<std::vector<Proxy
   return *tlvs == *d_proxyProtocolValuesSent;
 }
 
-bool TCPConnectionToBackend::isXFRFinished(const TCPResponse& response, const shared_ptr<IncomingTCPConnectionState>& clientConn)
+bool TCPConnectionToBackend::isXFRFinished(const TCPResponse& response, TCPQuery& query)
 {
   bool done = false;
   try {
@@ -626,20 +619,20 @@ bool TCPConnectionToBackend::isXFRFinished(const TCPResponse& response, const sh
         auto raw = unknownContent->getRawContent();
         auto serial = getSerialFromRawSOAContent(raw);
 
-        ++clientConn->d_xfrSerialCount;
-        if (clientConn->d_xfrMasterSerial == 0) {
+        ++query.d_xfrSerialCount;
+        if (query.d_xfrMasterSerial == 0) {
           // store the first SOA in our client's connection metadata
-          ++clientConn->d_xfrMasterSerialCount;
-          clientConn->d_xfrMasterSerial = serial;
+          ++query.d_xfrMasterSerialCount;
+          query.d_xfrMasterSerial = serial;
         }
-        else if (clientConn->d_xfrMasterSerial == serial) {
-          ++clientConn->d_xfrMasterSerialCount;
+        else if (query.d_xfrMasterSerial == serial) {
+          ++query.d_xfrMasterSerialCount;
           // figure out if it's end when receiving master's SOA again
-          if (clientConn->d_xfrSerialCount == 2) {
+          if (query.d_xfrSerialCount == 2) {
             // if there are only two SOA records marks a finished AXFR
             done = true;
           }
-          if (clientConn->d_xfrMasterSerialCount == 3) {
+          if (query.d_xfrMasterSerialCount == 3) {
             // receiving master's SOA 3 times marks a finished IXFR
             done = true;
           }
index f9d26e51800a6e64fbed8a1b0deda846827428bb..228bee4b3bfcc939147e0f003ddadf8390c5d315 100644 (file)
@@ -5,49 +5,12 @@
 #include "sstuff.hh"
 #include "tcpiohandler-mplexer.hh"
 #include "dnsdist.hh"
+#include "dnsdist-tcp.hh"
 
-struct TCPQuery
-{
-  TCPQuery()
-  {
-  }
-
-  TCPQuery(PacketBuffer&& buffer, IDState&& state): d_idstate(std::move(state)), d_buffer(std::move(buffer))
-  {
-  }
-
-  IDState d_idstate;
-  PacketBuffer d_buffer;
-  std::string d_proxyProtocolPayload;
-  bool d_proxyProtocolPayloadAdded{false};
-};
-
-class TCPConnectionToBackend;
-
-struct TCPResponse : public TCPQuery
-{
-  TCPResponse()
-  {
-    /* let's make Coverity happy */
-    memset(&d_cleartextDH, 0, sizeof(d_cleartextDH));
-  }
-
-  TCPResponse(PacketBuffer&& buffer, IDState&& state, std::shared_ptr<TCPConnectionToBackend> conn): TCPQuery(std::move(buffer), std::move(state)), d_connection(conn)
-  {
-    memset(&d_cleartextDH, 0, sizeof(d_cleartextDH));
-  }
-
-  std::shared_ptr<TCPConnectionToBackend> d_connection{nullptr};
-  dnsheader d_cleartextDH;
-  bool d_selfGenerated{false};
-};
-
-class IncomingTCPConnectionState;
-
-class TCPConnectionToBackend
+class TCPConnectionToBackend : public std::enable_shared_from_this<TCPConnectionToBackend>
 {
 public:
-  TCPConnectionToBackend(std::shared_ptr<DownstreamState>& ds, const struct timeval& now): d_responseBuffer(s_maxPacketCacheEntrySize), d_ds(ds), d_connectionStartTime(now), d_lastDataReceivedTime(now), d_enableFastOpen(ds->tcpFastOpen)
+  TCPConnectionToBackend(std::shared_ptr<DownstreamState>& ds, std::unique_ptr<FDMultiplexer>& mplexer, const struct timeval& now): d_responseBuffer(s_maxPacketCacheEntrySize), d_mplexer(mplexer), d_ds(ds), d_connectionStartTime(now), d_lastDataReceivedTime(now), d_enableFastOpen(ds->tcpFastOpen)
   {
     reconnect();
   }
@@ -64,8 +27,6 @@ public:
     }
   }
 
-  void assignToClientConnection(std::shared_ptr<IncomingTCPConnectionState>& clientConn, bool isXFR);
-
   int getHandle() const
   {
     if (!d_handler) {
@@ -118,10 +79,8 @@ public:
   /* whether we can accept new queries FOR THE SAME CLIENT */
   bool canAcceptNewQueries() const
   {
-    if (d_usedForXFR || d_connectionDied) {
+    if (d_connectionDied) {
       return false;
-      /* Don't reuse the TCP connection after an {A,I}XFR */
-      /* but don't reset it either, we will need to read more messages */
     }
 
     if ((d_pendingQueries.size() + d_pendingResponses.size()) >= d_ds->d_maxInFlightQueriesPerConn) {
@@ -139,7 +98,7 @@ public:
   /* whether a connection can be reused for a different client */
   bool canBeReused() const
   {
-    if (d_usedForXFR || d_connectionDied) {
+    if (d_connectionDied) {
       return false;
     }
     /* we can't reuse a connection where a proxy protocol payload has been sent,
@@ -163,7 +122,7 @@ public:
     return ds == d_ds;
   }
 
-  void queueQuery(TCPQuery&& query, std::shared_ptr<TCPConnectionToBackend>& sharedSelf);
+  void queueQuery(std::shared_ptr<TCPQuerySender>& sender, TCPQuery&& query);
   void handleTimeout(const struct timeval& now, bool write);
   void release();
 
@@ -177,7 +136,7 @@ public:
   std::string toString() const
   {
     ostringstream o;
-    o << "TCP connection to backend "<<(d_ds ? d_ds->getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket")<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? std::to_string((int)d_ioState->getState()) : "empty")<<", queries count is "<<d_queries<<", pending queries count is "<<d_pendingQueries.size()<<", "<<d_pendingResponses.size()<<" pending responses, linked to "<<(d_clientConn ? " a client" : "no client");
+    o << "TCP connection to backend "<<(d_ds ? d_ds->getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket")<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? std::to_string((int)d_ioState->getState()) : "empty")<<", queries count is "<<d_queries<<", pending queries count is "<<d_pendingQueries.size()<<", "<<d_pendingResponses.size()<<" pending responses, linked to "<<(d_sender ? " a client" : "no client");
     return o.str();
   }
 
@@ -191,7 +150,7 @@ private:
   static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
   static IOState queueNextQuery(std::shared_ptr<TCPConnectionToBackend>& conn);
   static IOState sendQuery(std::shared_ptr<TCPConnectionToBackend>& conn, const struct timeval& now);
-  static bool isXFRFinished(const TCPResponse& response, const shared_ptr<IncomingTCPConnectionState>& clientConn);
+  static bool isXFRFinished(const TCPResponse& response, TCPQuery& query);
 
   IOState handleResponse(std::shared_ptr<TCPConnectionToBackend>& conn, const struct timeval& now);
   uint16_t getQueryIdFromResponse();
@@ -247,16 +206,15 @@ private:
     return res;
   }
 
-  static const uint16_t s_xfrID;
-
   PacketBuffer d_responseBuffer;
   std::deque<TCPQuery> d_pendingQueries;
   std::unordered_map<uint16_t, TCPQuery> d_pendingResponses;
+  std::unique_ptr<FDMultiplexer>& d_mplexer;
   std::unique_ptr<std::vector<ProxyProtocolValue>> d_proxyProtocolValuesSent{nullptr};
   std::unique_ptr<TCPIOHandler> d_handler{nullptr};
   std::unique_ptr<IOStateHandler> d_ioState{nullptr};
   std::shared_ptr<DownstreamState> d_ds{nullptr};
-  std::shared_ptr<IncomingTCPConnectionState> d_clientConn;
+  std::shared_ptr<TCPQuerySender> d_sender{nullptr};
   TCPQuery d_currentQuery;
   struct timeval d_connectionStartTime;
   struct timeval d_lastDataReceivedTime;
@@ -268,6 +226,5 @@ private:
   bool d_fresh{true};
   bool d_enableFastOpen{false};
   bool d_connectionDied{false};
-  bool d_usedForXFR{false};
   bool d_proxyProtocolPayloadSent{false};
 };
index 5eed0884863b70725235edff0bb1d25a4dae4c0f..7db91dca5e01bd8963d9da5801c338d1efdc0656 100644 (file)
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "dolog.hh"
+#include "dnsdist-tcp.hh"
 
 class TCPClientThreadData
 {
@@ -14,48 +15,7 @@ public:
   std::unique_ptr<FDMultiplexer> mplexer{nullptr};
 };
 
-struct ConnectionInfo
-{
-  ConnectionInfo(ClientState* cs_): cs(cs_), fd(-1)
-  {
-  }
-  ConnectionInfo(ConnectionInfo&& rhs): remote(rhs.remote), cs(rhs.cs), fd(rhs.fd)
-  {
-    rhs.cs = nullptr;
-    rhs.fd = -1;
-  }
-
-  ConnectionInfo(const ConnectionInfo& rhs) = delete;
-  ConnectionInfo& operator=(const ConnectionInfo& rhs) = delete;
-
-  ConnectionInfo& operator=(ConnectionInfo&& rhs)
-  {
-    remote = rhs.remote;
-    cs = rhs.cs;
-    rhs.cs = nullptr;
-    fd = rhs.fd;
-    rhs.fd = -1;
-    return *this;
-  }
-
-  ~ConnectionInfo()
-  {
-    if (fd != -1) {
-      close(fd);
-      fd = -1;
-    }
-
-    if (cs) {
-      --cs->tcpCurrentConnections;
-    }
-  }
-
-  ComboAddress remote;
-  ClientState* cs{nullptr};
-  int fd{-1};
-};
-
-class IncomingTCPConnectionState
+class IncomingTCPConnectionState : public TCPQuerySender, public std::enable_shared_from_this<IncomingTCPConnectionState>
 {
 public:
   IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_threadData(threadData), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_ioState(make_unique<IOStateHandler>(threadData.mplexer, d_ci.fd)), d_connectionStartTime(now)
@@ -145,34 +105,35 @@ public:
   std::shared_ptr<TCPConnectionToBackend> getDownstreamConnection(std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs, const struct timeval& now);
   void registerActiveDownstreamConnection(std::shared_ptr<TCPConnectionToBackend>& conn);
 
-  std::unique_ptr<FDMultiplexer>& getIOMPlexer() const
-  {
-    return d_threadData.mplexer;
-  }
-
   static size_t clearAllDownstreamConnections();
 
   static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& conn, const struct timeval& now);
   static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
-  static void notifyIOError(std::shared_ptr<IncomingTCPConnectionState>& state, IDState&& query, const struct timeval& now);
+
   static IOState sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response);
   static void queueResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response);
+static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bool write);
 
   /* we take a copy of a shared pointer, not a reference, because the initial shared pointer might be released during the handling of the response */
-  static void handleResponse(std::shared_ptr<IncomingTCPConnectionState> state, const struct timeval& now, TCPResponse&& response);
-  static void handleXFRResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response);
-  static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bool write);
+  void handleResponse(const struct timeval& now, TCPResponse&& response) override;
+  void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override;
+  void notifyIOError(IDState&& query, const struct timeval& now) override;
 
   void terminateClientConnection();
   void queueQuery(TCPQuery&& query);
 
   bool canAcceptNewQueries(const struct timeval& now);
 
-  bool active() const
+  bool active() const override
   {
     return d_ioState != nullptr;
   }
 
+  const ClientState& getClientState() override
+  {
+    return *d_ci.cs;
+  }
+
   std::string toString() const
   {
     ostringstream o;
@@ -203,9 +164,6 @@ public:
   size_t d_proxyProtocolNeed{0};
   size_t d_queriesCount{0};
   size_t d_currentQueriesCount{0};
-  uint32_t d_xfrMasterSerial{0};
-  uint32_t d_xfrSerialCount{0};
-  uint8_t d_xfrMasterSerialCount{0};
   uint16_t d_querySize{0};
   State d_state{State::doingHandshake};
   bool d_isXFR{false};
diff --git a/pdns/dnsdistdist/dnsdist-tcp.hh b/pdns/dnsdistdist/dnsdist-tcp.hh
new file mode 100644 (file)
index 0000000..b932e5d
--- /dev/null
@@ -0,0 +1,296 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+struct ConnectionInfo
+{
+  ConnectionInfo(ClientState* cs_): cs(cs_), fd(-1)
+  {
+  }
+  ConnectionInfo(ConnectionInfo&& rhs): remote(rhs.remote), cs(rhs.cs), fd(rhs.fd)
+  {
+    rhs.cs = nullptr;
+    rhs.fd = -1;
+  }
+
+  ConnectionInfo(const ConnectionInfo& rhs) = delete;
+  ConnectionInfo& operator=(const ConnectionInfo& rhs) = delete;
+
+  ConnectionInfo& operator=(ConnectionInfo&& rhs)
+  {
+    remote = rhs.remote;
+    cs = rhs.cs;
+    rhs.cs = nullptr;
+    fd = rhs.fd;
+    rhs.fd = -1;
+    return *this;
+  }
+
+  ~ConnectionInfo()
+  {
+    if (fd != -1) {
+      close(fd);
+      fd = -1;
+    }
+
+    if (cs) {
+      --cs->tcpCurrentConnections;
+    }
+  }
+
+  ComboAddress remote;
+  ClientState* cs{nullptr};
+  int fd{-1};
+};
+
+struct InternalQuery
+{
+  InternalQuery()
+  {
+  }
+
+  InternalQuery(PacketBuffer&& buffer, IDState&& state): d_idstate(std::move(state)), d_buffer(std::move(buffer))
+  {
+  }
+
+  InternalQuery(InternalQuery&& rhs) :
+    d_idstate(std::move(rhs.d_idstate)), d_buffer(std::move(rhs.d_buffer)), d_proxyProtocolPayload(std::move(rhs.d_proxyProtocolPayload)), d_xfrMasterSerial(rhs.d_xfrMasterSerial), d_xfrSerialCount(rhs.d_xfrSerialCount), d_xfrMasterSerialCount(rhs.d_xfrMasterSerialCount), d_proxyProtocolPayloadAdded(rhs.d_proxyProtocolPayloadAdded)
+  {
+  }
+  InternalQuery& operator=(InternalQuery&& rhs)
+  {
+    d_idstate = std::move(rhs.d_idstate);
+    d_buffer = std::move(rhs.d_buffer);
+    d_proxyProtocolPayload = std::move(rhs.d_proxyProtocolPayload);
+    d_xfrMasterSerial = rhs.d_xfrMasterSerial;
+    d_xfrSerialCount = rhs.d_xfrSerialCount;
+    d_xfrMasterSerialCount = rhs.d_xfrMasterSerialCount;
+    d_proxyProtocolPayloadAdded = rhs.d_proxyProtocolPayloadAdded;
+    return *this;
+  }
+
+  InternalQuery(const InternalQuery& rhs) = delete;
+  InternalQuery& operator=(const InternalQuery& rhs) = delete;
+
+  bool isXFR() const
+  {
+    return d_idstate.qtype == QType::AXFR || d_idstate.qtype == QType::IXFR;
+  }
+
+  IDState d_idstate;
+  PacketBuffer d_buffer;
+  std::string d_proxyProtocolPayload;
+  uint32_t d_xfrMasterSerial{0};
+  uint32_t d_xfrSerialCount{0};
+  uint8_t d_xfrMasterSerialCount{0};
+  bool d_xfrStarted{false};
+  bool d_proxyProtocolPayloadAdded{false};
+};
+
+using TCPQuery = InternalQuery;
+
+class TCPConnectionToBackend;
+
+struct TCPResponse : public TCPQuery
+{
+  TCPResponse()
+  {
+    /* let's make Coverity happy */
+    memset(&d_cleartextDH, 0, sizeof(d_cleartextDH));
+  }
+
+  TCPResponse(PacketBuffer&& buffer, IDState&& state, std::shared_ptr<TCPConnectionToBackend> conn): TCPQuery(std::move(buffer), std::move(state)), d_connection(conn)
+  {
+    memset(&d_cleartextDH, 0, sizeof(d_cleartextDH));
+  }
+
+  std::shared_ptr<TCPConnectionToBackend> d_connection{nullptr};
+  dnsheader d_cleartextDH;
+  bool d_selfGenerated{false};
+};
+
+class TCPQuerySender
+{
+public:
+  virtual ~TCPQuerySender()
+  {
+  }
+
+  virtual bool active() const = 0;
+  virtual const ClientState& getClientState() = 0;
+  virtual void handleResponse(const struct timeval& now, TCPResponse&& response) = 0;
+  virtual void handleXFRResponse(const struct timeval& now, TCPResponse&& response) = 0;
+  virtual void notifyIOError(IDState&& query, const struct timeval& now) = 0;
+};
+
+struct CrossProtocolQuery
+{
+  CrossProtocolQuery()
+  {
+  }
+
+  CrossProtocolQuery(CrossProtocolQuery&& rhs) = delete;
+  virtual ~CrossProtocolQuery()
+  {
+  }
+
+  virtual std::shared_ptr<TCPQuerySender> getTCPQuerySender() = 0;
+
+  InternalQuery query;
+  std::shared_ptr<DownstreamState> downstream{nullptr};
+};
+
+class TCPClientCollection {
+public:
+  TCPClientCollection(size_t maxThreads);
+
+  int getThread()
+  {
+    if (d_numthreads == 0) {
+      throw std::runtime_error("No TCP worker thread yet");
+    }
+
+    uint64_t pos = d_pos++;
+    ++d_queued;
+    return d_tcpclientthreads.at(pos % d_numthreads).d_newConnectionPipe;
+  }
+
+  bool passConnectionToThread(std::unique_ptr<ConnectionInfo>&& conn)
+  {
+    if (d_numthreads == 0) {
+      throw std::runtime_error("No TCP worker thread yet");
+    }
+
+    uint64_t pos = d_pos++;
+    auto pipe = d_tcpclientthreads.at(pos % d_numthreads).d_newConnectionPipe;
+    auto tmp = conn.release();
+
+    if (write(pipe, &tmp, sizeof(tmp)) != sizeof(tmp)) {
+      delete tmp;
+      tmp = nullptr;
+      return false;
+    }
+    ++d_queued;
+    return true;
+  }
+
+  bool passCrossProtocolQueryToThread(std::unique_ptr<CrossProtocolQuery>&& cpq)
+  {
+    if (d_numthreads == 0) {
+      throw std::runtime_error("No TCP worker thread yet");
+    }
+
+    uint64_t pos = d_pos++;
+    auto pipe = d_tcpclientthreads.at(pos % d_numthreads).d_crossProtocolQueryPipe;
+    auto tmp = cpq.release();
+
+    if (write(pipe, &tmp, sizeof(tmp)) != sizeof(tmp)) {
+      delete tmp;
+      tmp = nullptr;
+      return false;
+    }
+
+    return true;
+  }
+
+  bool hasReachedMaxThreads() const
+  {
+    return d_numthreads >= d_maxthreads;
+  }
+
+  uint64_t getThreadsCount() const
+  {
+    return d_numthreads;
+  }
+
+  uint64_t getQueuedCount() const
+  {
+    return d_queued;
+  }
+
+  void decrementQueuedCount()
+  {
+    --d_queued;
+  }
+
+  void addTCPClientThread();
+
+private:
+  struct TCPWorkerThread
+  {
+    TCPWorkerThread()
+    {
+    }
+
+    TCPWorkerThread(int newConnPipe, int crossProtocolPipe): d_newConnectionPipe(newConnPipe), d_crossProtocolQueryPipe(crossProtocolPipe)
+    {
+    }
+
+    TCPWorkerThread(TCPWorkerThread&& rhs): d_newConnectionPipe(rhs.d_newConnectionPipe), d_crossProtocolQueryPipe(rhs.d_crossProtocolQueryPipe)
+    {
+      rhs.d_newConnectionPipe = -1;
+      rhs.d_crossProtocolQueryPipe = -1;
+    }
+
+    TCPWorkerThread& operator=(TCPWorkerThread&& rhs)
+    {
+      if (d_newConnectionPipe != -1) {
+        close(d_newConnectionPipe);
+      }
+      if (d_crossProtocolQueryPipe != -1) {
+        close(d_crossProtocolQueryPipe);
+      }
+
+      d_newConnectionPipe = rhs.d_newConnectionPipe;
+      d_crossProtocolQueryPipe = rhs.d_crossProtocolQueryPipe;
+      rhs.d_newConnectionPipe = -1;
+      rhs.d_crossProtocolQueryPipe = -1;
+
+      return *this;
+    }
+
+    TCPWorkerThread(const TCPWorkerThread& rhs) = delete;
+    TCPWorkerThread& operator=(const TCPWorkerThread&) = delete;
+
+    ~TCPWorkerThread()
+    {
+      if (d_newConnectionPipe != -1) {
+        close(d_newConnectionPipe);
+      }
+      if (d_crossProtocolQueryPipe != -1) {
+        close(d_crossProtocolQueryPipe);
+      }
+    }
+
+    int d_newConnectionPipe{-1};
+    int d_crossProtocolQueryPipe{-1};
+  };
+
+  std::mutex d_mutex;
+  std::vector<TCPWorkerThread> d_tcpclientthreads;
+  stat_t d_numthreads{0};
+  stat_t d_pos{0};
+  stat_t d_queued{0};
+  const uint64_t d_maxthreads{0};
+};
+
+extern std::unique_ptr<TCPClientCollection> g_tcpclientthreads;
index e89859c3e896eb517b0e78fd891004e35c8fae3a..296c09e6db2774727cd3a633d04f0767e36d3e1e 100644 (file)
@@ -20,6 +20,7 @@
 #include "dnsname.hh"
 #undef CERT
 #include "dnsdist.hh"
+#include "dnsdist-tcp.hh"
 #include "misc.hh"
 #include "dns.hh"
 #include "dolog.hh"
@@ -177,6 +178,11 @@ struct DOHServerConfig
     dohquerypair[0] = fd[1];
     dohquerypair[1] = fd[0];
 
+    setNonBlocking(dohquerypair[0]);
+    if (internalPipeBufferSize > 0) {
+      setPipeBufferSize(dohquerypair[0], internalPipeBufferSize);
+    }
+
     if (pipe(fd) < 0) {
       close(dohquerypair[0]);
       close(dohquerypair[1]);
@@ -186,11 +192,6 @@ struct DOHServerConfig
     dohresponsepair[0] = fd[1];
     dohresponsepair[1] = fd[0];
 
-    setNonBlocking(dohquerypair[0]);
-    if (internalPipeBufferSize > 0) {
-      setPipeBufferSize(dohquerypair[0], internalPipeBufferSize);
-    }
-
     setNonBlocking(dohresponsepair[0]);
     if (internalPipeBufferSize > 0) {
       setPipeBufferSize(dohresponsepair[0], internalPipeBufferSize);
@@ -198,6 +199,14 @@ struct DOHServerConfig
 
     setNonBlocking(dohresponsepair[1]);
 
+    if (pipe(fd) < 0) {
+      close(dohquerypair[0]);
+      close(dohquerypair[1]);
+      close(dohresponsepair[0]);
+      close(dohresponsepair[1]);
+      unixDie("Creating a pipe for DNS over HTTPS");
+    }
+
     h2o_config_init(&h2o_config);
     h2o_config.http2.idle_timeout = idleTimeout * 1000;
   }
@@ -465,13 +474,12 @@ static int processDOHQuery(DOHUnit* du)
     uint16_t qtype, qclass;
     unsigned int qnameWireLength = 0;
     DNSName qname(reinterpret_cast<const char*>(du->query.data()), du->query.size(), sizeof(dnsheader), false, &qtype, &qclass, &qnameWireLength);
-    DNSQuestion dq(&qname, qtype, qclass, &du->dest, &du->remote, du->query, DNSQuestion::Protocol::DoH, &queryRealTime);
+    DNSQuestion dq(&qname, qtype, qclass, &du->dest, &du->remote, du->query, dnsdist::Protocol::DoH, &queryRealTime);
     dq.ednsAdded = du->ednsAdded;
     dq.du = du;
     dq.sni = std::move(du->sni);
 
-    std::shared_ptr<DownstreamState> ss{nullptr};
-    auto result = processQuery(dq, cs, holders, ss);
+    auto result = processQuery(dq, cs, holders, du->downstream);
 
     if (result == ProcessQueryResult::Drop) {
       du->status_code = 403;
@@ -493,14 +501,14 @@ static int processDOHQuery(DOHUnit* du)
       return -1;
     }
 
-    if (ss == nullptr) {
+    if (du->downstream == nullptr) {
       du->status_code = 502;
       return -1;
     }
 
     ComboAddress dest = du->dest;
-    unsigned int idOffset = (ss->idOffset++) % ss->idStates.size();
-    IDState* ids = &ss->idStates[idOffset];
+    unsigned int idOffset = (du->downstream->idOffset++) % du->downstream->idStates.size();
+    IDState* ids = &du->downstream->idStates[idOffset];
     ids->age = 0;
     DOHUnit* oldDU = nullptr;
     if (ids->isInUse()) {
@@ -516,13 +524,13 @@ static int processDOHQuery(DOHUnit* du)
       /* the state was not in use.
          we reset 'oldDU' because it might have still been in use when we read it. */
       oldDU = nullptr;
-      ++ss->outstanding;
+      ++du->downstream->outstanding;
     }
     else {
       ids->du = nullptr;
       /* we are reusing a state, no change in outstanding but if there was an existing DOHUnit we need
          to handle it because it's about to be overwritten. */
-      ++ss->reuseds;
+      ++du->downstream->reuseds;
       ++g_stats.downstreamTimeouts;
       handleDOHTimeout(oldDU);
     }
@@ -554,16 +562,16 @@ static int processDOHQuery(DOHUnit* du)
       ids->destHarvested = false;
     }
 
-    if (ss->useProxyProtocol) {
+    if (du->downstream->useProxyProtocol) {
       addProxyProtocol(dq);
     }
 
-    int fd = pickBackendSocketForSending(ss);
+    int fd = pickBackendSocketForSending(du->downstream);
     try {
       /* you can't touch du after this line, because it might already have been freed */
-      ssize_t ret = udpClientSendRequestToBackend(ss, fd, du->query);
+      ssize_t ret = udpClientSendRequestToBackend(du->downstream, fd, du->query);
 
-      if(ret < 0) {
+      if (ret < 0) {
         /* we are about to handle the error, make sure that
            this pointer is not accessed when the state is cleaned,
            but first check that it still belongs to us */
@@ -571,9 +579,9 @@ static int processDOHQuery(DOHUnit* du)
           ids->du = nullptr;
           du->release();
           duRefCountIncremented = false;
-          --ss->outstanding;
+          --du->downstream->outstanding;
         }
-        ++ss->sendErrors;
+        ++du->downstream->sendErrors;
         ++g_stats.downstreamSendErrors;
         du->status_code = 502;
         return -1;
@@ -586,7 +594,7 @@ static int processDOHQuery(DOHUnit* du)
       throw;
     }
 
-    vinfolog("Got query for %s|%s from %s (https), relayed to %s", ids->qname.toString(), QType(ids->qtype).toString(), remote.toStringWithPort(), ss->getName());
+    vinfolog("Got query for %s|%s from %s (https), relayed to %s", ids->qname.toString(), QType(ids->qtype).toString(), remote.toStringWithPort(), du->downstream->getName());
   }
   catch(const std::exception& e) {
     vinfolog("Got an error in DOH question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what());
@@ -1120,6 +1128,94 @@ static void dnsdistclient(int qsock)
   }
 }
 
+class DoHTCPCrossQuerySender : public TCPQuerySender
+{
+public:
+  DoHTCPCrossQuerySender(DOHUnit* du_): du(du_)
+  {
+  }
+
+  ~DoHTCPCrossQuerySender()
+  {
+    if (du != nullptr) {
+      du->release();
+    }
+  }
+
+  bool active() const override
+  {
+    return true;
+  }
+
+  const ClientState& getClientState() override
+  {
+    if (!du || !du->dsc || !du->dsc->cs) {
+      throw std::runtime_error("No query associated to this DoHTCPCrossQuerySender");
+    }
+
+    return *du->dsc->cs;
+  }
+
+  void handleResponse(const struct timeval& now, TCPResponse&& response) override
+  {
+    if (!du) {
+      return;
+    }
+
+    if (du->rsock == -1) {
+      return;
+    }
+
+    du->response = std::move(response.d_buffer);
+
+    auto sent = write(du->rsock, &du, sizeof(du));
+    if (sent != sizeof(du)) {
+      du->release();
+      du = nullptr;
+   }
+  }
+
+  void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override
+  {
+    throw std::runtime_error("Oops");
+  }
+
+  void notifyIOError(IDState&& query, const struct timeval& now) override
+  {
+    throw std::runtime_error("Oops");
+  }
+
+private:
+  DOHUnit* du{nullptr};
+};
+
+class DoHCrossProtocolQuery : public CrossProtocolQuery
+{
+public:
+  DoHCrossProtocolQuery(DOHUnit* du_): du(du_)
+  {
+    query = InternalQuery(std::move(du->query), std::move(du->ids));
+    downstream = du->downstream;
+  }
+
+  ~DoHCrossProtocolQuery()
+  {
+    if (du != nullptr) {
+      du->release();
+    }
+  }
+
+  std::shared_ptr<TCPQuerySender> getTCPQuerySender() override
+  {
+    auto sender = std::make_shared<DoHTCPCrossQuerySender>(du);
+    du = nullptr;
+    return sender;
+  }
+
+private:
+  DOHUnit* du{nullptr};
+};
+
 /* Called in the main DoH thread if h2o finds that dnsdist gave us an answer by writing into
    the dohresponsepair[0] side of the pipe so from:
    - handleDOHTimeout() when we did not get a response fast enough (called
@@ -1147,6 +1243,28 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err)
     return;
   }
 
+  if (!du->response.empty() && !du->tcp) {
+    const dnsheader* dh = reinterpret_cast<const struct dnsheader*>(du->response.data());
+
+    if (dh->tc) {
+      /* restoring the original ID */
+      dnsheader* queryDH = reinterpret_cast<struct dnsheader*>(du->query.data());
+      queryDH->id = htons(du->ids.origID);
+
+      auto cpq = std::make_unique<DoHCrossProtocolQuery>(du);
+
+      du->get();
+      du->tcp = true;
+
+      if (g_tcpclientthreads && g_tcpclientthreads->passCrossProtocolQueryToThread(std::move(cpq))) {
+        return;
+      }
+      else {
+        du->release();
+      }
+    }
+  }
+
   if (du->self) {
     // we are back in the h2o main thread now, so we don't risk
     // a race (h2o killing the query) when accessing du->req anymore
@@ -1452,6 +1570,32 @@ void dohThread(ClientState* cs)
   }
 }
 
+void DOHUnit::handleUDPResponse(PacketBuffer&& udpResponse, IDState&& state)
+{
+  static_assert(sizeof(*this) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail");
+
+  response = std::move(udpResponse);
+  ids = std::move(state);
+
+  auto du = this;
+  ssize_t sent = write(rsock, &du, sizeof(du));
+  if (sent != sizeof(this)) {
+    if (errno == EAGAIN || errno == EWOULDBLOCK) {
+      ++g_stats.dohResponsePipeFull;
+      vinfolog("Unable to pass a DoH response to the DoH worker thread because the pipe is full");
+    }
+    else {
+      vinfolog("Unable to pass a DoH response to the DoH worker thread because we couldn't write to the pipe: %s", stringerror());
+    }
+
+    /* at this point we have the only remaining pointer on this
+       DOHUnit object since we did set ids->du to nullptr earlier,
+       except if we got the response before the pointer could be
+       released by the frontend */
+    release();
+  }
+}
+
 #else /* HAVE_DNS_OVER_HTTPS */
 
 void handleDOHTimeout(DOHUnit* oldDU)
index 316fd782062937d23be4aa057e3cbab0aa6e4fa4..bc013bd0892e7cc7e4df6775c2cb41006c28cf1e 100644 (file)
@@ -307,7 +307,7 @@ BOOST_AUTO_TEST_CASE(test_LMDB) {
   ComboAddress lc("192.0.2.1:53");
   ComboAddress rem("192.0.2.128:42");
   PacketBuffer packet(sizeof(dnsheader));
-  auto proto = DNSQuestion::Protocol::DoUDP;
+  auto proto = dnsdist::Protocol::DoUDP;
   struct timespec queryRealTime;
   gettime(&queryRealTime, true);
   struct timespec expiredTime;
@@ -387,7 +387,7 @@ BOOST_AUTO_TEST_CASE(test_CDB) {
   ComboAddress lc("192.0.2.1:53");
   ComboAddress rem("192.0.2.128:42");
   PacketBuffer packet(sizeof(dnsheader));
-  auto proto = DNSQuestion::Protocol::DoUDP;
+  auto proto = dnsdist::Protocol::DoUDP;
   struct timespec queryRealTime;
   gettime(&queryRealTime, true);
   struct timespec expiredTime;
index b1a31886e3048608d9057126c89993a2ea376f91..d1f5080836098b852516d2eba7145c7032564d54 100644 (file)
@@ -100,7 +100,7 @@ static DNSQuestion getDQ(const DNSName* providedName = nullptr)
 
   uint16_t qtype = QType::A;
   uint16_t qclass = QClass::IN;
-  auto proto = DNSQuestion::Protocol::DoUDP;
+  auto proto = dnsdist::Protocol::DoUDP;
   gettime(&queryRealTime, true);
 
   DNSQuestion dq(providedName ? providedName : &qname, qtype, qclass, &lc, &rem, packet, proto, &queryRealTime);
index 6ae2a6e14bdc4673443f90e116c2a0f324523f06..70ccbd0020e5c0faaf80a5dce958df2671a4712a 100644 (file)
@@ -23,7 +23,7 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) {
   ComboAddress lc("127.0.0.1:53");
   ComboAddress rem("192.0.2.1:42");
   PacketBuffer packet(sizeof(dnsheader));
-  auto proto = DNSQuestion::Protocol::DoUDP;
+  auto proto = dnsdist::Protocol::DoUDP;
   struct timespec queryRealTime;
   gettime(&queryRealTime, true);
   struct timespec expiredTime;
index 4f506fefc1ca517edb1d939dbbc7e84b31f467c5..c1a1398ee8553a40b039f4dacb5a2916a5cf2fc7 100644 (file)
@@ -2647,6 +2647,11 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
         /* the backend descriptor becomes ready */
         dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setReady(desc);
       } },
+      /* no more query from the client for now */
+      { ExpectedStep::ExpectedRequest::readFromClient, IOState::NeedRead, 0 , [&threadData](int desc, const ExpectedStep& step) {
+        /* the client descriptor becomes NOT ready */
+        dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setNotReady(-1);
+      } },
       /* read the response (1) from the backend  */
       { ExpectedStep::ExpectedRequest::readFromBackend, IOState::Done, 2 },
       { ExpectedStep::ExpectedRequest::readFromBackend, IOState::Done, axfrResponses.at(0).size() - 2 },
index 58e4cfba8ffb347ec2915b57651aebbc447cca70..9a2d12bb4d81aa76b22aab4777a5c0e79edec055 100644 (file)
@@ -173,7 +173,10 @@ struct DOHUnit
 #else /* HAVE_DNS_OVER_HTTPS */
 #include <unordered_map>
 
+#include "dnsdist-idstate.hh"
+
 struct st_h2o_req_t;
+struct DownstreamState;
 
 struct DOHUnit
 {
@@ -199,9 +202,12 @@ struct DOHUnit
     }
   }
 
+  void handleUDPResponse(PacketBuffer&& response, IDState&& state);
+
   std::vector<std::pair<std::string, std::string>> headers;
   PacketBuffer query;
   PacketBuffer response;
+  IDState ids;
   std::string sni;
   std::string path;
   std::string scheme;
@@ -211,6 +217,7 @@ struct DOHUnit
   st_h2o_req_t* req{nullptr};
   DOHUnit** self{nullptr};
   DOHServerConfig* dsc{nullptr};
+  std::shared_ptr<DownstreamState> downstream{nullptr};
   std::string contentType;
   std::atomic<uint64_t> d_refcnt{1};
   size_t query_at{0};
@@ -224,6 +231,9 @@ struct DOHUnit
   */
   uint16_t status_code{200};
   bool ednsAdded{false};
+  /* whether the query was re-sent to the backend over
+     TCP after receiving a truncated answer over UDP */
+  bool tcp{false};
 
   std::string getHTTPPath() const;
   std::string getHTTPHost() const;
index 2466297be9dce9d26b6fd3b13f26eb072f966293..2d49ad24398a98599f60c0e08f3f686a53e151c1 100644 (file)
@@ -62,7 +62,7 @@ static void validateECS(const PacketBuffer& packet, const ComboAddress& expected
   uint16_t qtype;
   uint16_t qclass;
   DNSName qname(reinterpret_cast<const char*>(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed);
-  DNSQuestion dq(&qname, qtype, qclass, nullptr, &rem, const_cast<PacketBuffer&>(packet), DNSQuestion::Protocol::DoUDP, nullptr);
+  DNSQuestion dq(&qname, qtype, qclass, nullptr, &rem, const_cast<PacketBuffer&>(packet), dnsdist::Protocol::DoUDP, nullptr);
   BOOST_CHECK(parseEDNSOptions(dq));
   BOOST_REQUIRE(dq.ednsOptions != nullptr);
   BOOST_CHECK_EQUAL(dq.ednsOptions->size(), 1U);
@@ -113,7 +113,7 @@ BOOST_AUTO_TEST_CASE(test_addXPF)
     BOOST_CHECK_EQUAL(qname, name);
     BOOST_CHECK(qtype == QType::A);
 
-    DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, DNSQuestion::Protocol::DoUDP, &queryTime);
+    DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime);
 
     BOOST_CHECK(addXPF(dq, xpfOptionCode));
     BOOST_CHECK(packet.size() > query.size());
@@ -132,7 +132,7 @@ BOOST_AUTO_TEST_CASE(test_addXPF)
     BOOST_CHECK_EQUAL(qname, name);
     BOOST_CHECK(qtype == QType::A);
 
-    DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, DNSQuestion::Protocol::DoUDP, &queryTime);
+    DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime);
 
     BOOST_REQUIRE(!addXPF(dq, xpfOptionCode));
     BOOST_CHECK_EQUAL(packet.size(), 4096U);
@@ -150,7 +150,7 @@ BOOST_AUTO_TEST_CASE(test_addXPF)
     BOOST_CHECK_EQUAL(qname, name);
     BOOST_CHECK(qtype == QType::A);
 
-    DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, DNSQuestion::Protocol::DoUDP, &queryTime);
+    DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime);
 
     /* add trailing data */
     const size_t trailingDataSize = 10;
@@ -337,7 +337,7 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed)
   BOOST_CHECK(qtype == QType::A);
   BOOST_CHECK(qclass == QClass::IN);
 
-  DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, DNSQuestion::Protocol::DoUDP, nullptr);
+  DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr);
   /* Parse the options before handling ECS, simulating a Lua rule asking for EDNS Options */
   BOOST_CHECK(!parseEDNSOptions(dq));
 
@@ -360,7 +360,7 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed)
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
   BOOST_CHECK(qclass == QClass::IN);
-  DNSQuestion dq2(&qname, qtype, qclass, nullptr, &remote, packet, DNSQuestion::Protocol::DoUDP, nullptr);
+  DNSQuestion dq2(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr);
 
   BOOST_CHECK(handleEDNSClientSubnet(dq2, ednsAdded, ecsAdded));
   BOOST_CHECK_GT(packet.size(), query.size());
@@ -439,7 +439,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECSAlreadyParsed) {
   BOOST_CHECK(qtype == QType::A);
   BOOST_CHECK(qclass == QClass::IN);
 
-  DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, DNSQuestion::Protocol::DoUDP, nullptr);
+  DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr);
   /* Parse the options before handling ECS, simulating a Lua rule asking for EDNS Options */
   BOOST_CHECK(parseEDNSOptions(dq));
 
@@ -461,7 +461,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECSAlreadyParsed) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
   BOOST_CHECK(qclass == QClass::IN);
-  DNSQuestion dq2(&qname, qtype, qclass, nullptr, &remote, packet, DNSQuestion::Protocol::DoUDP, nullptr);
+  DNSQuestion dq2(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr);
 
   BOOST_CHECK(handleEDNSClientSubnet(dq2, ednsAdded, ecsAdded));
   BOOST_CHECK_GT(packet.size(), query.size());
@@ -537,7 +537,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSameSizeAlreadyParsed) {
   BOOST_CHECK(qtype == QType::A);
   BOOST_CHECK(qclass == QClass::IN);
 
-  DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, DNSQuestion::Protocol::DoUDP, nullptr);
+  DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr);
   dq.ecsOverride = true;
 
   /* Parse the options before handling ECS, simulating a Lua rule asking for EDNS Options */
@@ -1430,7 +1430,7 @@ BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenLastOption) {
 
 static DNSQuestion getDNSQuestion(const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& lc, const ComboAddress& rem, const struct timespec& realTime, PacketBuffer& query)
 {
-  return DNSQuestion(&qname, qtype, qclass, &lc, &rem, query, DNSQuestion::Protocol::DoUDP, &realTime);
+  return DNSQuestion(&qname, qtype, qclass, &lc, &rem, query, dnsdist::Protocol::DoUDP, &realTime);
 }
 
 static DNSQuestion turnIntoResponse(const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& lc, const ComboAddress& rem, const struct timespec& queryRealTime, PacketBuffer&  query, bool resizeBuffer=true)
@@ -1933,7 +1933,7 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) {
     unsigned int consumed = 0;
     uint16_t qtype;
     DNSName qname(reinterpret_cast<const char*>(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed);
-    DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, DNSQuestion::Protocol::DoUDP, &queryTime);
+    DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime);
 
     BOOST_CHECK(setNegativeAndAdditionalSOA(dq, true, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , 5));
     BOOST_CHECK(packet.size() > query.size());
@@ -1957,7 +1957,7 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) {
     unsigned int consumed = 0;
     uint16_t qtype;
     DNSName qname(reinterpret_cast<const char*>(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed);
-    DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, DNSQuestion::Protocol::DoUDP, &queryTime);
+    DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime);
 
     BOOST_CHECK(setNegativeAndAdditionalSOA(dq, true, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , 5));
     BOOST_CHECK(packet.size() > queryWithEDNS.size());
@@ -1985,7 +1985,7 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) {
     unsigned int consumed = 0;
     uint16_t qtype;
     DNSName qname(reinterpret_cast<const char*>(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed);
-    DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, DNSQuestion::Protocol::DoUDP, &queryTime);
+    DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime);
 
     BOOST_CHECK(setNegativeAndAdditionalSOA(dq, false, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , 5));
     BOOST_CHECK(packet.size() > query.size());
@@ -2009,7 +2009,7 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) {
     unsigned int consumed = 0;
     uint16_t qtype;
     DNSName qname(reinterpret_cast<const char*>(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed);
-    DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, DNSQuestion::Protocol::DoUDP, &queryTime);
+    DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime);
 
     BOOST_CHECK(setNegativeAndAdditionalSOA(dq, false, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , 5));
     BOOST_CHECK(packet.size() > queryWithEDNS.size());
@@ -2050,7 +2050,7 @@ BOOST_AUTO_TEST_CASE(getEDNSOptionsWithoutEDNS) {
     uint16_t qtype;
     uint16_t qclass;
     DNSName qname(reinterpret_cast<const char*>(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed);
-    DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, DNSQuestion::Protocol::DoUDP, nullptr);
+    DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr);
 
     BOOST_CHECK(!parseEDNSOptions(dq));
   }
@@ -2071,7 +2071,7 @@ BOOST_AUTO_TEST_CASE(getEDNSOptionsWithoutEDNS) {
     uint16_t qtype;
     uint16_t qclass;
     DNSName qname(reinterpret_cast<const char*>(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed);
-    DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, DNSQuestion::Protocol::DoUDP, nullptr);
+    DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr);
 
     BOOST_CHECK(!parseEDNSOptions(dq));
   }
@@ -2092,7 +2092,7 @@ BOOST_AUTO_TEST_CASE(getEDNSOptionsWithoutEDNS) {
     uint16_t qtype;
     uint16_t qclass;
     DNSName qname(reinterpret_cast<const char*>(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed);
-    DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, DNSQuestion::Protocol::DoUDP, nullptr);
+    DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr);
 
     BOOST_CHECK(!parseEDNSOptions(dq));
   }
index 40759684df1c0cd78f782b6529d69f2028d5c9db..7af31fa110aaf0ec47aeb91a731eff18534a43bf 100644 (file)
@@ -50,7 +50,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) {
 
       uint32_t key = 0;
       boost::optional<Netmask> subnet;
-      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime);
+      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime);
       bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP);
       BOOST_CHECK_EQUAL(found, false);
       BOOST_CHECK(!subnet);
@@ -81,7 +81,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) {
       pwQ.getHeader()->rd = 1;
       uint32_t key = 0;
       boost::optional<Netmask> subnet;
-      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime);
+      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime);
       bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP);
       if (found == true) {
         auto removed = PC.expungeByName(a);
@@ -100,7 +100,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) {
       pwQ.getHeader()->rd = 1;
       uint32_t key = 0;
       boost::optional<Netmask> subnet;
-      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime);
+      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime);
       if (PC.get(dq, pwQ.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP)) {
         matches++;
       }
@@ -161,7 +161,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSharded) {
 
       uint32_t key = 0;
       boost::optional<Netmask> subnet;
-      DNSQuestion dq(&a, QType::AAAA, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime);
+      DNSQuestion dq(&a, QType::AAAA, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime);
       bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP);
       BOOST_CHECK_EQUAL(found, false);
       BOOST_CHECK(!subnet);
@@ -192,7 +192,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSharded) {
       pwQ.getHeader()->rd = 1;
       uint32_t key = 0;
       boost::optional<Netmask> subnet;
-      DNSQuestion dq(&a, QType::AAAA, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime);
+      DNSQuestion dq(&a, QType::AAAA, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime);
       if (PC.get(dq, pwQ.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP)) {
         matches++;
       }
@@ -257,7 +257,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheTCP) {
       /* UDP */
       uint32_t key = 0;
       boost::optional<Netmask> subnet;
-      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime);
+      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime);
       bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP);
       BOOST_CHECK_EQUAL(found, false);
       BOOST_CHECK(!subnet);
@@ -272,7 +272,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheTCP) {
       /* same but over TCP */
       uint32_t key = 0;
       boost::optional<Netmask> subnet;
-      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoTCP, &queryTime);
+      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoTCP, &queryTime);
       bool found = PC.get(dq, 0, &key, subnet, dnssecOK, !receivedOverUDP);
       BOOST_CHECK_EQUAL(found, false);
       BOOST_CHECK(!subnet);
@@ -316,7 +316,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheServFailTTL) {
 
     uint32_t key = 0;
     boost::optional<Netmask> subnet;
-    DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime);
+    DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime);
     bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP);
     BOOST_CHECK_EQUAL(found, false);
     BOOST_CHECK(!subnet);
@@ -369,7 +369,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNoDataTTL) {
 
     uint32_t key = 0;
     boost::optional<Netmask> subnet;
-    DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime);
+    DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime);
     bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP);
     BOOST_CHECK_EQUAL(found, false);
     BOOST_CHECK(!subnet);
@@ -421,7 +421,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNXDomainTTL) {
 
     uint32_t key = 0;
     boost::optional<Netmask> subnet;
-    DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime);
+    DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime);
     bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP);
     BOOST_CHECK_EQUAL(found, false);
     BOOST_CHECK(!subnet);
@@ -470,7 +470,7 @@ static void threadMangler(unsigned int offset)
 
       uint32_t key = 0;
       boost::optional<Netmask> subnet;
-      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime);
+      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime);
       g_PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP);
 
       g_PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none);
@@ -500,7 +500,7 @@ static void threadReader(unsigned int offset)
 
       uint32_t key = 0;
       boost::optional<Netmask> subnet;
-      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime);
+      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime);
       bool found = g_PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP);
       if (!found) {
        g_missing++;
@@ -576,7 +576,7 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) {
     ComboAddress remote("192.0.2.1");
     struct timespec queryTime;
     gettime(&queryTime);
-    DNSQuestion dq(&qname, QType::AAAA, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime);
+    DNSQuestion dq(&qname, QType::AAAA, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime);
     bool found = PC.get(dq, 0, &key, subnetOut, dnssecOK, receivedOverUDP);
     BOOST_CHECK_EQUAL(found, false);
     BOOST_REQUIRE(subnetOut);
@@ -619,7 +619,7 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) {
     ComboAddress remote("192.0.2.1");
     struct timespec queryTime;
     gettime(&queryTime);
-    DNSQuestion dq(&qname, QType::AAAA, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime);
+    DNSQuestion dq(&qname, QType::AAAA, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime);
     bool found = PC.get(dq, 0, &secondKey, subnetOut, dnssecOK, receivedOverUDP);
     BOOST_CHECK_EQUAL(found, false);
     BOOST_CHECK_EQUAL(secondKey, key);
@@ -695,7 +695,7 @@ BOOST_AUTO_TEST_CASE(test_PCDNSSECCollision) {
     ComboAddress remote("192.0.2.1");
     struct timespec queryTime;
     gettime(&queryTime);
-    DNSQuestion dq(&qname, QType::AAAA, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime);
+    DNSQuestion dq(&qname, QType::AAAA, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime);
     bool found = PC.get(dq, 0, &key, subnetOut, true, receivedOverUDP);
     BOOST_CHECK_EQUAL(found, false);