]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Process responses in the right thread for incoming TCP/DoT queries
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 30 Aug 2021 15:04:01 +0000 (17:04 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 13 Sep 2021 13:28:28 +0000 (15:28 +0200)
15 files changed:
pdns/bpf-filter.cc
pdns/bpf-filter.hh
pdns/bpf-filter.main.ebpf
pdns/bpf-filter.qname.ebpf
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdistdist/dnsdist-healthchecks.cc
pdns/dnsdistdist/dnsdist-nghttp2.cc
pdns/dnsdistdist/dnsdist-nghttp2.hh
pdns/dnsdistdist/dnsdist-tcp-upstream.hh
pdns/dnsdistdist/dnsdist-tcp.hh
pdns/dnsdistdist/doh.cc
pdns/misc.hh
pdns/sstuff.hh
regression-tests.dnsdist/dnsdisttests.py

index 59912dbeb16b76f35d259180186d6a902092b3b3..f6fdceebba0516c8106cc31bd0c6c93aeea9391c 100644 (file)
@@ -142,23 +142,23 @@ struct QNameValue
 BPFFilter::BPFFilter(uint32_t maxV4Addresses, uint32_t maxV6Addresses, uint32_t maxQNames): d_maps(Maps()), d_maxV4(maxV4Addresses), d_maxV6(maxV6Addresses), d_maxQNames(maxQNames)
 {
   auto maps = d_maps.lock();
-  maps->d_v4map.fd = bpf_create_map(BPF_MAP_TYPE_HASH, sizeof(uint32_t), sizeof(uint64_t), (int) maxV4Addresses);
-  if (maps->d_v4map.fd == -1) {
+  maps->d_v4map = FDWrapper(bpf_create_map(BPF_MAP_TYPE_HASH, sizeof(uint32_t), sizeof(uint64_t), (int) maxV4Addresses));
+  if (maps->d_v4map.getHandle() == -1) {
     throw std::runtime_error("Error creating a BPF v4 map of size " + std::to_string(maxV4Addresses) + ": " + stringerror());
   }
 
-  maps->d_v6map.fd = bpf_create_map(BPF_MAP_TYPE_HASH, sizeof(struct KeyV6), sizeof(uint64_t), (int) maxV6Addresses);
-  if (maps->d_v6map.fd == -1) {
+  maps->d_v6map = FDWrapper(bpf_create_map(BPF_MAP_TYPE_HASH, sizeof(struct KeyV6), sizeof(uint64_t), (int) maxV6Addresses));
+  if (maps->d_v6map.getHandle() == -1) {
     throw std::runtime_error("Error creating a BPF v6 map of size " + std::to_string(maxV6Addresses) + ": " + stringerror());
   }
 
-  maps->d_qnamemap.fd = bpf_create_map(BPF_MAP_TYPE_HASH, sizeof(struct QNameKey), sizeof(struct QNameValue), (int) maxQNames);
-  if (maps->d_qnamemap.fd == -1) {
+  maps->d_qnamemap = FDWrapper(bpf_create_map(BPF_MAP_TYPE_HASH, sizeof(struct QNameKey), sizeof(struct QNameValue), (int) maxQNames));
+  if (maps->d_qnamemap.getHandle() == -1) {
     throw std::runtime_error("Error creating a BPF qname map of size " + std::to_string(maxQNames) + ": " + stringerror());
   }
 
-  maps->d_filtermap.fd = bpf_create_map(BPF_MAP_TYPE_PROG_ARRAY, sizeof(uint32_t), sizeof(uint32_t), 1);
-  if (maps->d_filtermap.fd == -1) {
+  maps->d_filtermap = FDWrapper(bpf_create_map(BPF_MAP_TYPE_PROG_ARRAY, sizeof(uint32_t), sizeof(uint32_t), 1));
+  if (maps->d_filtermap.getHandle() == -1) {
     throw std::runtime_error("Error creating a BPF program map of size 1: " + stringerror());
   }
 
@@ -166,12 +166,12 @@ BPFFilter::BPFFilter(uint32_t maxV4Addresses, uint32_t maxV6Addresses, uint32_t
 #include "bpf-filter.main.ebpf"
   };
 
-  d_mainfilter.fd = bpf_prog_load(BPF_PROG_TYPE_SOCKET_FILTER,
-                                  main_filter,
-                                  sizeof(main_filter),
-                                  "GPL",
-                                  0);
-  if (d_mainfilter.fd == -1) {
+  d_mainfilter = FDWrapper(bpf_prog_load(BPF_PROG_TYPE_SOCKET_FILTER,
+                                         main_filter,
+                                         sizeof(main_filter),
+                                         "GPL",
+                                         0));
+  if (d_mainfilter.getHandle() == -1) {
     throw std::runtime_error("Error loading BPF main filter: " + stringerror());
   }
 
@@ -179,17 +179,18 @@ BPFFilter::BPFFilter(uint32_t maxV4Addresses, uint32_t maxV6Addresses, uint32_t
 #include "bpf-filter.qname.ebpf"
   };
 
-  d_qnamefilter.fd = bpf_prog_load(BPF_PROG_TYPE_SOCKET_FILTER,
-                                   qname_filter,
-                                   sizeof(qname_filter),
-                                   "GPL",
-                                   0);
-  if (d_qnamefilter.fd == -1) {
+  d_qnamefilter = FDWrapper(bpf_prog_load(BPF_PROG_TYPE_SOCKET_FILTER,
+                                          qname_filter,
+                                          sizeof(qname_filter),
+                                          "GPL",
+                                          0));
+  if (d_qnamefilter.getHandle() == -1) {
     throw std::runtime_error("Error loading BPF qname filter: " + stringerror());
   }
 
   uint32_t key = 0;
-  int res = bpf_update_elem(maps->d_filtermap.fd, &key, &d_qnamefilter.fd, BPF_ANY);
+  int qnamefd = d_qnamefilter.getHandle();
+  int res = bpf_update_elem(maps->d_filtermap.getHandle(), &key, &qnamefd, BPF_ANY);
   if (res != 0) {
     throw std::runtime_error("Error updating BPF filters map: " + stringerror());
   }
@@ -197,7 +198,8 @@ BPFFilter::BPFFilter(uint32_t maxV4Addresses, uint32_t maxV6Addresses, uint32_t
 
 void BPFFilter::addSocket(int sock)
 {
-  int res = setsockopt(sock, SOL_SOCKET, SO_ATTACH_BPF, &d_mainfilter.fd, sizeof(d_mainfilter.fd));
+  int fd = d_mainfilter.getHandle();
+  int res = setsockopt(sock, SOL_SOCKET, SO_ATTACH_BPF, &fd, sizeof(fd));
 
   if (res != 0) {
     throw std::runtime_error("Error attaching BPF filter to this socket: " + stringerror());
@@ -206,7 +208,8 @@ void BPFFilter::addSocket(int sock)
 
 void BPFFilter::removeSocket(int sock)
 {
-  int res = setsockopt(sock, SOL_SOCKET, SO_DETACH_BPF, &d_mainfilter.fd, sizeof(d_mainfilter.fd));
+  int fd = d_mainfilter.getHandle();
+  int res = setsockopt(sock, SOL_SOCKET, SO_DETACH_BPF, &fd, sizeof(fd));
 
   if (res != 0) {
     throw std::runtime_error("Error detaching BPF filter from this socket: " + stringerror());
@@ -224,12 +227,12 @@ void BPFFilter::block(const ComboAddress& addr)
       throw std::runtime_error("Table full when trying to block " + addr.toString());
     }
 
-    res = bpf_lookup_elem(maps->d_v4map.fd, &key, &counter);
+    res = bpf_lookup_elem(maps->d_v4map.getHandle(), &key, &counter);
     if (res != -1) {
       throw std::runtime_error("Trying to block an already blocked address: " + addr.toString());
     }
 
-    res = bpf_update_elem(maps->d_v4map.fd, &key, &counter, BPF_NOEXIST);
+    res = bpf_update_elem(maps->d_v4map.getHandle(), &key, &counter, BPF_NOEXIST);
     if (res == 0) {
       maps->d_v4Count++;
     }
@@ -246,12 +249,12 @@ void BPFFilter::block(const ComboAddress& addr)
       throw std::runtime_error("Table full when trying to block " + addr.toString());
     }
 
-    res = bpf_lookup_elem(maps->d_v6map.fd, &key, &counter);
+    res = bpf_lookup_elem(maps->d_v6map.getHandle(), &key, &counter);
     if (res != -1) {
       throw std::runtime_error("Trying to block an already blocked address: " + addr.toString());
     }
 
-    res = bpf_update_elem(maps->d_v6map.fd, key, &counter, BPF_NOEXIST);
+    res = bpf_update_elem(maps->d_v6map.getHandle(), key, &counter, BPF_NOEXIST);
     if (res == 0) {
       maps->d_v6Count++;
     }
@@ -268,7 +271,7 @@ void BPFFilter::unblock(const ComboAddress& addr)
   if (addr.isIPv4()) {
     uint32_t key = htonl(addr.sin4.sin_addr.s_addr);
     auto maps = d_maps.lock();
-    res = bpf_delete_elem(maps->d_v4map.fd, &key);
+    res = bpf_delete_elem(maps->d_v4map.getHandle(), &key);
     if (res == 0) {
       maps->d_v4Count--;
     }
@@ -281,7 +284,7 @@ void BPFFilter::unblock(const ComboAddress& addr)
     }
 
     auto maps = d_maps.lock();
-    res = bpf_delete_elem(maps->d_v6map.fd, key);
+    res = bpf_delete_elem(maps->d_v6map.getHandle(), key);
     if (res == 0) {
       maps->d_v6Count--;
     }
@@ -313,12 +316,12 @@ void BPFFilter::block(const DNSName& qname, uint16_t qtype)
       throw std::runtime_error("Table full when trying to block " + qname.toLogString());
     }
 
-    int res = bpf_lookup_elem(maps->d_qnamemap.fd, &key, &value);
+    int res = bpf_lookup_elem(maps->d_qnamemap.getHandle(), &key, &value);
     if (res != -1) {
       throw std::runtime_error("Trying to block an already blocked qname: " + qname.toLogString());
     }
 
-    res = bpf_update_elem(maps->d_qnamemap.fd, &key, &value, BPF_NOEXIST);
+    res = bpf_update_elem(maps->d_qnamemap.getHandle(), &key, &value, BPF_NOEXIST);
     if (res == 0) {
       maps->d_qNamesCount++;
     }
@@ -342,7 +345,7 @@ void BPFFilter::unblock(const DNSName& qname, uint16_t qtype)
 
   {
     auto maps = d_maps.lock();
-    int res = bpf_delete_elem(maps->d_qnamemap.fd, &key);
+    int res = bpf_delete_elem(maps->d_qnamemap.getHandle(), &key);
     if (res == 0) {
       maps->d_qNamesCount--;
     }
@@ -378,28 +381,28 @@ std::vector<std::pair<ComboAddress, uint64_t> > BPFFilter::getAddrStats()
   memset(&v6Key, 0, sizeof(v6Key));
 
   auto maps = d_maps.lock();
-  int res = bpf_get_next_key(maps->d_v4map.fd, &v4Key, &nextV4Key);
+  int res = bpf_get_next_key(maps->d_v4map.getHandle(), &v4Key, &nextV4Key);
 
   while (res == 0) {
     v4Key = nextV4Key;
-    if (bpf_lookup_elem(maps->d_v4map.fd, &v4Key, &value) == 0) {
+    if (bpf_lookup_elem(maps->d_v4map.getHandle(), &v4Key, &value) == 0) {
       v4Addr.sin_addr.s_addr = ntohl(v4Key);
       result.push_back(make_pair(ComboAddress(&v4Addr), value));
     }
 
-    res = bpf_get_next_key(maps->d_v4map.fd, &v4Key, &nextV4Key);
+    res = bpf_get_next_key(maps->d_v4map.getHandle(), &v4Key, &nextV4Key);
   }
 
-  res = bpf_get_next_key(maps->d_v6map.fd, &v6Key, &nextV6Key);
+  res = bpf_get_next_key(maps->d_v6map.getHandle(), &v6Key, &nextV6Key);
 
   while (res == 0) {
-    if (bpf_lookup_elem(maps->d_v6map.fd, &nextV6Key, &value) == 0) {
+    if (bpf_lookup_elem(maps->d_v6map.getHandle(), &nextV6Key, &value) == 0) {
       memcpy(&v6Addr.sin6_addr.s6_addr, &nextV6Key, sizeof(nextV6Key));
 
       result.push_back(make_pair(ComboAddress(&v6Addr), value));
     }
 
-    res = bpf_get_next_key(maps->d_v6map.fd, &nextV6Key, &nextV6Key);
+    res = bpf_get_next_key(maps->d_v6map.getHandle(), &nextV6Key, &nextV6Key);
   }
   return result;
 }
@@ -414,15 +417,15 @@ std::vector<std::tuple<DNSName, uint16_t, uint64_t> > BPFFilter::getQNameStats()
 
   auto maps = d_maps.lock();
   result.reserve(maps->d_qNamesCount);
-  int res = bpf_get_next_key(maps->d_qnamemap.fd, &key, &nextKey);
+  int res = bpf_get_next_key(maps->d_qnamemap.getHandle(), &key, &nextKey);
 
   while (res == 0) {
-    if (bpf_lookup_elem(maps->d_qnamemap.fd, &nextKey, &value) == 0) {
+    if (bpf_lookup_elem(maps->d_qnamemap.getHandle(), &nextKey, &value) == 0) {
       nextKey.qname[sizeof(nextKey.qname) - 1 ] = '\0';
       result.push_back(std::make_tuple(DNSName((const char*) nextKey.qname, sizeof(nextKey.qname), 0, false), value.qtype, value.counter));
     }
 
-    res = bpf_get_next_key(maps->d_qnamemap.fd, &nextKey, &nextKey);
+    res = bpf_get_next_key(maps->d_qnamemap.getHandle(), &nextKey, &nextKey);
   }
   return result;
 }
@@ -434,7 +437,7 @@ uint64_t BPFFilter::getHits(const ComboAddress& requestor)
     uint32_t key = htonl(requestor.sin4.sin_addr.s_addr);
 
     auto maps = d_maps.lock();
-    int res = bpf_lookup_elem(maps->d_v4map.fd, &key, &counter);
+    int res = bpf_lookup_elem(maps->d_v4map.getHandle(), &key, &counter);
     if (res == 0) {
       return counter;
     }
@@ -447,7 +450,7 @@ uint64_t BPFFilter::getHits(const ComboAddress& requestor)
     }
 
     auto maps = d_maps.lock();
-    int res = bpf_lookup_elem(maps->d_v6map.fd, &key, &counter);
+    int res = bpf_lookup_elem(maps->d_v6map.getHandle(), &key, &counter);
     if (res == 0) {
       return counter;
     }
index abf805e4c2cccf6a4edd005394336efa24c92f72..482b61e31abd6e6b16614e0b0d69b425ee533793 100644 (file)
@@ -41,16 +41,6 @@ public:
 
 private:
 #ifdef HAVE_EBPF
-  struct FDWrapper
-  {
-    ~FDWrapper()
-    {
-      if (fd != -1) {
-        close(fd);
-      }
-    }
-    int fd{-1};
-  };
   struct Maps
   {
     FDWrapper d_v4map;
index a8ffa37ad1f1d52fc5170387ddc6f53ebebc088d..8a82f3bce4723e43f20d4fd4bf26b3dded8224bb 100644 (file)
@@ -6,7 +6,7 @@ BPF_JMP_IMM(BPF_JEQ,BPF_REG_1,ntohs(0x86dd),11),
 BPF_JMP_IMM(BPF_JNE,BPF_REG_1,ntohs(0x0800),109),
 BPF_LD_ABS(BPF_W,-2097126),
 BPF_STX_MEM(BPF_W,BPF_REG_10,BPF_REG_0,-256),
-BPF_LD_MAP_FD(BPF_REG_1,maps->d_v4map.fd),
+BPF_LD_MAP_FD(BPF_REG_1,maps->d_v4map.getHandle()),
 BPF_MOV64_REG(BPF_REG_2,BPF_REG_10),
 BPF_ALU64_IMM(BPF_ADD,BPF_REG_2,-256),
 BPF_RAW_INSN(BPF_JMP|BPF_CALL,0,0,0,BPF_FUNC_map_lookup_elem),
@@ -45,7 +45,7 @@ BPF_LD_ABS(BPF_B,-2097116),
 BPF_STX_MEM(BPF_B,BPF_REG_10,BPF_REG_0,-242),
 BPF_LD_ABS(BPF_B,-2097115),
 BPF_STX_MEM(BPF_B,BPF_REG_10,BPF_REG_0,-241),
-BPF_LD_MAP_FD(BPF_REG_1,maps->d_v6map.fd),
+BPF_LD_MAP_FD(BPF_REG_1,maps->d_v6map.getHandle()),
 BPF_MOV64_REG(BPF_REG_2,BPF_REG_10),
 BPF_ALU64_IMM(BPF_ADD,BPF_REG_2,-256),
 BPF_RAW_INSN(BPF_JMP|BPF_CALL,0,0,0,BPF_FUNC_map_lookup_elem),
@@ -97,7 +97,7 @@ BPF_JMP_IMM(BPF_JGT,BPF_REG_8,63,17),
 BPF_JMP_IMM(BPF_JNE,BPF_REG_8,0,18),
 BPF_LD_ABS(BPF_H,21),
 BPF_MOV64_REG(BPF_REG_6,BPF_REG_0),
-BPF_LD_MAP_FD(BPF_REG_1,maps->d_qnamemap.fd),
+BPF_LD_MAP_FD(BPF_REG_1,maps->d_qnamemap.getHandle()),
 BPF_MOV64_REG(BPF_REG_2,BPF_REG_10),
 BPF_ALU64_IMM(BPF_ADD,BPF_REG_2,-256),
 BPF_RAW_INSN(BPF_JMP|BPF_CALL,0,0,0,BPF_FUNC_map_lookup_elem),
@@ -128,7 +128,7 @@ BPF_ALU64_IMM(BPF_ADD,BPF_REG_8,-1),
 BPF_STX_MEM(BPF_W,BPF_REG_6,BPF_REG_8,60),
 BPF_ALU64_IMM(BPF_AND,BPF_REG_1,255),
 BPF_STX_MEM(BPF_W,BPF_REG_6,BPF_REG_1,56),
-BPF_LD_MAP_FD(BPF_REG_2,maps->d_filtermap.fd),
+BPF_LD_MAP_FD(BPF_REG_2,maps->d_filtermap.getHandle()),
 BPF_MOV64_REG(BPF_REG_1,BPF_REG_6),
 BPF_MOV64_IMM(BPF_REG_3,0),
 BPF_RAW_INSN(BPF_JMP|BPF_CALL,0,0,0,BPF_FUNC_tail_call),
index febb7b862c47283f5963f99fd47b98440b53b28d..7eb0519bd10e092e0266f2db5233cf06de7ed66d 100644 (file)
@@ -4078,7 +4078,7 @@ BPF_MOV64_IMM(BPF_REG_9,255),
 BPF_ALU64_REG(BPF_ADD,BPF_REG_9,BPF_REG_7),
 BPF_RAW_INSN(BPF_LD|BPF_IND|BPF_H,BPF_REG_0,BPF_REG_9,0,0),
 BPF_MOV64_REG(BPF_REG_6,BPF_REG_0),
-BPF_LD_MAP_FD(BPF_REG_1,maps->d_qnamemap.fd),
+BPF_LD_MAP_FD(BPF_REG_1,maps->d_qnamemap.getHandle()),
 BPF_MOV64_REG(BPF_REG_2,BPF_REG_10),
 BPF_ALU64_IMM(BPF_ADD,BPF_REG_2,-256),
 BPF_RAW_INSN(BPF_JMP|BPF_CALL,0,0,0,BPF_FUNC_map_lookup_elem),
index 1b478530c95405bb865076e0110029eedaf71cff..71dac2f47c7953ee2b1d134827bad25f9ce5e433 100644 (file)
@@ -124,10 +124,13 @@ std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getDownstrea
   return downstream;
 }
 
-static void tcpClientThread(int pipefd, int crossProtocolPipeFD);
+static void tcpClientThread(int pipefd, int crossProtocolQueriesPipeFD, int crossProtocolResponsesListenPipeFD, int crossProtocolResponsesWritePipeFD);
 
 TCPClientCollection::TCPClientCollection(size_t maxThreads): d_tcpclientthreads(maxThreads), d_maxthreads(maxThreads)
 {
+  for (size_t idx = 0; idx < maxThreads; idx++) {
+    addTCPClientThread();
+  }
 }
 
 void TCPClientCollection::addTCPClientThread()
@@ -166,20 +169,25 @@ void TCPClientCollection::addTCPClientThread()
     return;
   }
 
-  int crossProtocolFDs[2] = { -1, -1};
-  if (!preparePipe(crossProtocolFDs, "cross-protocol")) {
+  int crossProtocolQueriesFDs[2] = { -1, -1};
+  if (!preparePipe(crossProtocolQueriesFDs, "cross-protocol queries")) {
+    return;
+  }
+
+  int crossProtocolResponsesFDs[2] = { -1, -1};
+  if (!preparePipe(crossProtocolResponsesFDs, "cross-protocol responses")) {
     return;
   }
 
   vinfolog("Adding TCP Client thread");
 
   {
-    std::lock_guard<std::mutex> lock(d_mutex);
-
     if (d_numthreads >= d_tcpclientthreads.size()) {
       vinfolog("Adding a new TCP client thread would exceed the vector size (%d/%d), skipping. Consider increasing the maximum amount of TCP client threads with setMaxTCPClientThreads() in the configuration.", d_numthreads.load(), d_tcpclientthreads.size());
-      close(crossProtocolFDs[0]);
-      close(crossProtocolFDs[1]);
+      close(crossProtocolQueriesFDs[0]);
+      close(crossProtocolQueriesFDs[1]);
+      close(crossProtocolResponsesFDs[0]);
+      close(crossProtocolResponsesFDs[1]);
       close(pipefds[0]);
       close(pipefds[1]);
       return;
@@ -187,15 +195,17 @@ void TCPClientCollection::addTCPClientThread()
 
     /* from now on this side of the pipe will be managed by that object,
        no need to worry about it */
-    TCPWorkerThread worker(pipefds[1], crossProtocolFDs[1]);
+    TCPWorkerThread worker(pipefds[1], crossProtocolQueriesFDs[1], crossProtocolResponsesFDs[1]);
     try {
-      std::thread t1(tcpClientThread, pipefds[0], crossProtocolFDs[0]);
+      std::thread t1(tcpClientThread, pipefds[0], crossProtocolQueriesFDs[0], crossProtocolResponsesFDs[0], crossProtocolResponsesFDs[1]);
       t1.detach();
     }
     catch (const std::runtime_error& e) {
       /* the thread creation failed, don't leak */
       errlog("Error creating a TCP thread: %s", e.what());
       close(pipefds[0]);
+      close(crossProtocolQueriesFDs[0]);
+      close(crossProtocolResponsesFDs[0]);
       return;
     }
 
@@ -471,10 +481,74 @@ void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPRe
   queueResponse(state, now, std::move(response));
 }
 
+struct TCPCrossProtocolResponse
+{
+  TCPCrossProtocolResponse(TCPResponse&& response, std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now): d_response(std::move(response)), d_state(state), d_now(now)
+  {
+  }
+
+  TCPResponse d_response;
+  std::shared_ptr<IncomingTCPConnectionState> d_state;
+  struct timeval d_now;
+};
+
+class TCPCrossProtocolQuerySender : public TCPQuerySender
+{
+public:
+  TCPCrossProtocolQuerySender(std::shared_ptr<IncomingTCPConnectionState>& state, int responseDescriptor): d_state(state), d_responseDesc(responseDescriptor)
+  {
+  }
+
+  bool active() const override
+  {
+    return d_state->active();
+  }
+
+  const ClientState* getClientState() const override
+  {
+    return d_state->getClientState();
+  }
+
+  void handleResponse(const struct timeval& now, TCPResponse&& response) override
+  {
+    if (d_responseDesc == -1) {
+      throw std::runtime_error("Invalid pipe descriptor in TCP Cross Protocol Query Sender");
+    }
+
+    auto ptr = new TCPCrossProtocolResponse(std::move(response), d_state, now);
+    static_assert(sizeof(ptr) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail");
+    ssize_t sent = write(d_responseDesc, &ptr, sizeof(ptr));
+    if (sent != sizeof(ptr)) {
+      if (errno == EAGAIN || errno == EWOULDBLOCK) {
+        vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because the pipe is full");
+      }
+      else {
+        vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror());
+      }
+      delete ptr;
+    }
+  }
+
+  void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override
+  {
+    handleResponse(now, std::move(response));
+  }
+
+  void notifyIOError(IDState&& query, const struct timeval& now) override
+  {
+    TCPResponse response(PacketBuffer(), std::move(query), nullptr);
+    handleResponse(now, std::move(response));
+  }
+
+private:
+  std::shared_ptr<IncomingTCPConnectionState> d_state;
+  int d_responseDesc{-1};
+};
+
 class TCPCrossProtocolQuery : public CrossProtocolQuery
 {
 public:
-  TCPCrossProtocolQuery(PacketBuffer&& buffer, IDState&& ids, std::shared_ptr<DownstreamState>& ds, std::shared_ptr<IncomingTCPConnectionState>& sender): d_sender(sender)
+  TCPCrossProtocolQuery(PacketBuffer&& buffer, IDState&& ids, std::shared_ptr<DownstreamState>& ds, std::shared_ptr<TCPCrossProtocolQuerySender>& sender): d_sender(sender)
   {
     query = InternalQuery(std::move(buffer), std::move(ids));
     downstream = ds;
@@ -492,7 +566,7 @@ public:
   }
 
 private:
-  std::shared_ptr<IncomingTCPConnectionState> d_sender;
+  std::shared_ptr<TCPCrossProtocolQuerySender> d_sender;
 };
 
 static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
@@ -621,8 +695,8 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
   ++state->d_currentQueriesCount;
 
   if (ds->isDoH()) {
-    std::shared_ptr<TCPQuerySender> incoming = state;
-    auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(state->d_buffer), std::move(ids), ds, state);
+    auto incoming = std::make_shared<TCPCrossProtocolQuerySender>(state, state->d_threadData.crossProtocolResponsesPipe);
+    auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(state->d_buffer), std::move(ids), ds, incoming);
 
     ds->passCrossProtocolQuery(std::move(cpq));
     return;
@@ -630,7 +704,6 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
 
   prependSizeToTCPQuery(state->d_buffer, 0);
 
-#warning FIXME: handle DoH backends here
   auto downstreamConnection = state->getDownstreamConnection(ds, dq.proxyProtocolValues, now);
 
   bool proxyProtocolPayloadAdded = false;
@@ -1080,7 +1153,40 @@ static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& par
   }
 }
 
-static void tcpClientThread(int pipefd, int crossProtocolPipeFD)
+static void handleCrossProtocolResponse(int pipefd, FDMultiplexer::funcparam_t& param)
+{
+  TCPCrossProtocolResponse* tmp{nullptr};
+
+  ssize_t got = read(pipefd, &tmp, sizeof(tmp));
+  if (got == 0) {
+    throw std::runtime_error("EOF while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
+  }
+  else if (got == -1) {
+    if (errno == EAGAIN || errno == EINTR) {
+      return;
+    }
+    throw std::runtime_error("Error while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror());
+  }
+  else if (got != sizeof(tmp)) {
+    throw std::runtime_error("Partial read while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
+  }
+
+  auto response = std::move(*tmp);
+  delete tmp;
+  tmp = nullptr;
+
+  if (response.d_response.d_buffer.empty()) {
+    response.d_state->notifyIOError(std::move(response.d_response.d_idstate), response.d_now);
+  }
+  else if (response.d_response.d_idstate.qtype == QType::AXFR || response.d_response.d_idstate.qtype == QType::IXFR) {
+    response.d_state->handleXFRResponse(response.d_now, std::move(response.d_response));
+  }
+  else {
+    response.d_state->handleXFRResponse(response.d_now, std::move(response.d_response));
+  }
+}
+
+static void tcpClientThread(int pipefd, int crossProtocolQueriesPipeFD, int crossProtocolResponsesListenPipeFD, int crossProtocolResponsesWritePipeFD)
 {
   /* we get launched with a pipe on which we receive file descriptors from clients that we own
      from that point on */
@@ -1088,9 +1194,11 @@ static void tcpClientThread(int pipefd, int crossProtocolPipeFD)
   setThreadName("dnsdist/tcpClie");
 
   TCPClientThreadData data;
-
+  /* this is the writing end! */
+  data.crossProtocolResponsesPipe = crossProtocolResponsesWritePipeFD;
   data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data);
-  data.mplexer->addReadFD(crossProtocolPipeFD, handleCrossProtocolQuery, &data);
+  data.mplexer->addReadFD(crossProtocolQueriesPipeFD, handleCrossProtocolQuery, &data);
+  data.mplexer->addReadFD(crossProtocolResponsesListenPipeFD, handleCrossProtocolResponse, &data);
 
   struct timeval now;
   gettimeofday(&now, nullptr);
index 57d5322b04c9282e5f81cf1baf5cfd13884e8f18..dbcbbcc4c8b8cb79ae6f93913ed77229c8fe4a1a 100644 (file)
@@ -1312,7 +1312,7 @@ public:
     return true;
   }
 
-  const ClientState* getClientState() override
+  const ClientState* getClientState() const override
   {
     return &d_cs;
   }
@@ -2549,6 +2549,9 @@ int main(int argc, char** argv)
       g_maxTCPClientThreads = 1;
     }
 
+    /* we need to create the TCP worker threads before the
+       acceptor ones, otherwise we might crash when processing
+       the first TCP query */
     g_tcpclientthreads = std::make_unique<TCPClientCollection>(*g_maxTCPClientThreads);
 
     initDoHWorkers();
@@ -2590,13 +2593,6 @@ int main(int argc, char** argv)
     }
     handleQueuedHealthChecks(*mplexer, true);
 
-    /* we need to create the TCP worker threads before the
-       acceptor ones, otherwise we might crash when processing
-       the first TCP query */
-    while (!g_tcpclientthreads->hasReachedMaxThreads()) {
-      g_tcpclientthreads->addTCPClientThread();
-    }
-
     for(auto& cs : g_frontends) {
       if (cs->dohFrontend != nullptr) {
 #ifdef HAVE_DNS_OVER_HTTPS
index b8c280ac63b21225e03586ca0204550db4f6da29..7ae87d9c19ba3110932a1ba3a10ae2a455aaaf15 100644 (file)
@@ -196,7 +196,7 @@ public:
     return true;
   }
 
-  const ClientState* getClientState() override
+  const ClientState* getClientState() const override
   {
     return nullptr;
   }
@@ -398,7 +398,7 @@ bool queueHealthCheck(std::unique_ptr<FDMultiplexer>& mplexer, const std::shared
     else if (ds->isDoH()) {
       InternalQuery query(std::move(packet), IDState());
       auto sender = std::shared_ptr<TCPQuerySender>(new HealthCheckQuerySender(data));
-      if (!sendH2Query(ds, mplexer, sender, std::move(query))) {
+      if (!sendH2Query(ds, mplexer, sender, std::move(query), true)) {
         updateHealthCheckResult(data->d_ds, data->d_initial, false);
       }
     }
index 14ceccd8f8bba7078b586169a1cfacd72e1fe75f..2b503b95dff89de553eebc43df9e57b908630d96 100644 (file)
@@ -1128,14 +1128,14 @@ bool setupDoHClientProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx)
 #endif /* HAVE_NGHTTP2 */
 }
 
-bool sendH2Query(const std::shared_ptr<DownstreamState>& ds, std::unique_ptr<FDMultiplexer>& mplexer, std::shared_ptr<TCPQuerySender>& sender, InternalQuery&& query)
+bool sendH2Query(const std::shared_ptr<DownstreamState>& ds, std::unique_ptr<FDMultiplexer>& mplexer, std::shared_ptr<TCPQuerySender>& sender, InternalQuery&& query, bool healthCheck)
 {
 #ifdef HAVE_NGHTTP2
   struct timeval now;
   gettimeofday(&now, nullptr);
 
   auto newConnection = std::make_shared<DoHConnectionToBackend>(ds, mplexer, now);
-  newConnection->setHealthCheck(true);
+  newConnection->setHealthCheck(healthCheck);
   newConnection->queueQuery(sender, std::move(query));
   return true;
 #else /* HAVE_NGHTTP2 */
index 1f3aa23bdd8c11f93c48321549e9d020a65da104..5949cd3cdda7ef28e9c3490cb3846381aae38252 100644 (file)
@@ -68,4 +68,4 @@ bool setupDoHClientProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx);
 
 /* opens a new HTTP/2 connection to the supplied backend (attached to the supplied multiplexer), sends the query,
    waits for the response to come back or an error to occur then notifies the sender, closing the connection. */
-bool sendH2Query(const std::shared_ptr<DownstreamState>& ds, std::unique_ptr<FDMultiplexer>& mplexer, std::shared_ptr<TCPQuerySender>& sender, InternalQuery&& query);
+bool sendH2Query(const std::shared_ptr<DownstreamState>& ds, std::unique_ptr<FDMultiplexer>& mplexer, std::shared_ptr<TCPQuerySender>& sender, InternalQuery&& query, bool healthCheck);
index 498fea31e72ce660257784fa0a7c5c3d9831c932..e42c8acf74c7c86c5e1ffb2f74efb6e69d72421c 100644 (file)
@@ -13,6 +13,7 @@ public:
   LocalHolders holders;
   LocalStateHolder<vector<DNSDistResponseRuleAction> > localRespRuleActions;
   std::unique_ptr<FDMultiplexer> mplexer{nullptr};
+  int crossProtocolResponsesPipe{-1};
 };
 
 class IncomingTCPConnectionState : public TCPQuerySender, public std::enable_shared_from_this<IncomingTCPConnectionState>
@@ -132,7 +133,7 @@ static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bo
     return d_ioState != nullptr;
   }
 
-  const ClientState* getClientState() override
+  const ClientState* getClientState() const override
   {
     return d_ci.cs;
   }
index d891f31cddc33dbdad054f746d264381a176d386..312d0b72a46507ada04e3eada3e4b22ccc294f6d 100644 (file)
@@ -144,7 +144,7 @@ public:
   }
 
   virtual bool active() const = 0;
-  virtual const ClientState* getClientState() = 0;
+  virtual const ClientState* getClientState() const = 0;
   virtual void handleResponse(const struct timeval& now, TCPResponse&& response) = 0;
   virtual void handleXFRResponse(const struct timeval& now, TCPResponse&& response) = 0;
   virtual void notifyIOError(IDState&& query, const struct timeval& now) = 0;
@@ -192,7 +192,7 @@ public:
 
     uint64_t pos = d_pos++;
     ++d_queued;
-    return d_tcpclientthreads.at(pos % d_numthreads).d_newConnectionPipe;
+    return d_tcpclientthreads.at(pos % d_numthreads).d_newConnectionPipe.getHandle();
   }
 
   bool passConnectionToThread(std::unique_ptr<ConnectionInfo>&& conn)
@@ -202,7 +202,7 @@ public:
     }
 
     uint64_t pos = d_pos++;
-    auto pipe = d_tcpclientthreads.at(pos % d_numthreads).d_newConnectionPipe;
+    auto pipe = d_tcpclientthreads.at(pos % d_numthreads).d_newConnectionPipe.getHandle();
     auto tmp = conn.release();
 
     if (write(pipe, &tmp, sizeof(tmp)) != sizeof(tmp)) {
@@ -221,7 +221,7 @@ public:
     }
 
     uint64_t pos = d_pos++;
-    auto pipe = d_tcpclientthreads.at(pos % d_numthreads).d_crossProtocolQueryPipe;
+    auto pipe = d_tcpclientthreads.at(pos % d_numthreads).d_crossProtocolQueriesPipe.getHandle();
     auto tmp = cpq.release();
 
     if (write(pipe, &tmp, sizeof(tmp)) != sizeof(tmp)) {
@@ -253,62 +253,30 @@ public:
     --d_queued;
   }
 
+private:
   void addTCPClientThread();
 
-private:
   struct TCPWorkerThread
   {
     TCPWorkerThread()
     {
     }
 
-    TCPWorkerThread(int newConnPipe, int crossProtocolPipe) :
-      d_newConnectionPipe(newConnPipe), d_crossProtocolQueryPipe(crossProtocolPipe)
-    {
-    }
-
-    TCPWorkerThread(TCPWorkerThread&& rhs) :
-      d_newConnectionPipe(rhs.d_newConnectionPipe), d_crossProtocolQueryPipe(rhs.d_crossProtocolQueryPipe)
+    TCPWorkerThread(int newConnPipe, int crossProtocolQueriesPipe, int crossProtocolResponsesPipe) :
+      d_newConnectionPipe(newConnPipe), d_crossProtocolQueriesPipe(crossProtocolQueriesPipe), d_crossProtocolResponsesPipe(crossProtocolResponsesPipe)
     {
-      rhs.d_newConnectionPipe = -1;
-      rhs.d_crossProtocolQueryPipe = -1;
-    }
-
-    TCPWorkerThread& operator=(TCPWorkerThread&& rhs)
-    {
-      if (d_newConnectionPipe != -1) {
-        close(d_newConnectionPipe);
-      }
-      if (d_crossProtocolQueryPipe != -1) {
-        close(d_crossProtocolQueryPipe);
-      }
-
-      d_newConnectionPipe = rhs.d_newConnectionPipe;
-      d_crossProtocolQueryPipe = rhs.d_crossProtocolQueryPipe;
-      rhs.d_newConnectionPipe = -1;
-      rhs.d_crossProtocolQueryPipe = -1;
-
-      return *this;
     }
 
+    TCPWorkerThread(TCPWorkerThread&& rhs) = default;
+    TCPWorkerThread& operator=(TCPWorkerThread&& rhs) = default;
     TCPWorkerThread(const TCPWorkerThread& rhs) = delete;
     TCPWorkerThread& operator=(const TCPWorkerThread&) = delete;
 
-    ~TCPWorkerThread()
-    {
-      if (d_newConnectionPipe != -1) {
-        close(d_newConnectionPipe);
-      }
-      if (d_crossProtocolQueryPipe != -1) {
-        close(d_crossProtocolQueryPipe);
-      }
-    }
-
-    int d_newConnectionPipe{-1};
-    int d_crossProtocolQueryPipe{-1};
+    FDWrapper d_newConnectionPipe;
+    FDWrapper d_crossProtocolQueriesPipe;
+    FDWrapper d_crossProtocolResponsesPipe;
   };
 
-  std::mutex d_mutex;
   std::vector<TCPWorkerThread> d_tcpclientthreads;
   stat_t d_numthreads{0};
   stat_t d_pos{0};
index ef08f48e68a12933db32e89803674d31d21a4485..fad20531a974315f889cdaf313842bc166c3efcb 100644 (file)
@@ -424,7 +424,7 @@ public:
     return true;
   }
 
-  const ClientState* getClientState() override
+  const ClientState* getClientState() const override
   {
     if (!du || !du->dsc || !du->dsc->cs) {
       throw std::runtime_error("No query associated to this DoHTCPCrossQuerySender");
index c4f54575773febdb724b7daf28a09b040ac0ab74..5eefe8054a47e869c97117e5bb7064ce644adb8f 100644 (file)
@@ -635,3 +635,45 @@ std::string makeLuaString(const std::string& in);
 
 // Used in NID and L64 records
 struct NodeOrLocatorID { uint8_t content[8]; };
+
+struct FDWrapper
+{
+  FDWrapper()
+  {
+  }
+
+  FDWrapper(int desc): d_fd(desc)
+  {
+  }
+
+  ~FDWrapper()
+  {
+    if (d_fd != -1) {
+      close(d_fd);
+      d_fd = -1;
+    }
+  }
+
+  FDWrapper(FDWrapper&& rhs): d_fd(rhs.d_fd)
+  {
+    rhs.d_fd = -1;
+  }
+
+  FDWrapper& operator=(FDWrapper&& rhs)
+  {
+    if (d_fd) {
+      close(d_fd);
+    }
+    d_fd = rhs.d_fd;
+    rhs.d_fd = -1;
+    return *this;
+  }
+
+  int getHandle() const
+  {
+    return d_fd;
+  }
+
+private:
+  int d_fd{-1};
+};
index f84d50c36409d586f46cb69b06a5464069309f56..50a669bf86882d312e90ac03141eb1c28030f19a 100644 (file)
@@ -65,6 +65,9 @@ public:
 
   Socket& operator=(Socket&& rhs)
   {
+    if (d_socket != -1) {
+      close(d_socket);
+    }
     d_socket = rhs.d_socket;
     rhs.d_socket = -1;
     d_buffer = std::move(rhs.d_buffer);
index 4647b9279414f696263a8aecfebfa45d2fffdc04..083d72ee93ff9bcbe6ece092ee1df8a802d58111 100644 (file)
@@ -263,6 +263,9 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
               (conn, _) = sock.accept()
             except ssl.SSLError:
               continue
+            except ConnectionResetError:
+              continue
+
             conn.settimeout(5.0)
             data = conn.recv(2)
             if not data:
@@ -348,6 +351,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
                 continue
             except ConnectionResetError:
               continue
+
             conn.settimeout(5.0)
             h2conn = h2.connection.H2Connection(config=config)
             h2conn.initiate_connection()