From: Remi Gacogne Date: Mon, 30 Aug 2021 15:04:01 +0000 (+0200) Subject: dnsdist: Process responses in the right thread for incoming TCP/DoT queries X-Git-Tag: dnsdist-1.7.0-alpha1~23^2~21 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=ae3b96d9c4b8268a03e3fa6b06b8a93d1df13778;p=thirdparty%2Fpdns.git dnsdist: Process responses in the right thread for incoming TCP/DoT queries --- diff --git a/pdns/bpf-filter.cc b/pdns/bpf-filter.cc index 59912dbeb1..f6fdceebba 100644 --- a/pdns/bpf-filter.cc +++ b/pdns/bpf-filter.cc @@ -142,23 +142,23 @@ struct QNameValue BPFFilter::BPFFilter(uint32_t maxV4Addresses, uint32_t maxV6Addresses, uint32_t maxQNames): d_maps(Maps()), d_maxV4(maxV4Addresses), d_maxV6(maxV6Addresses), d_maxQNames(maxQNames) { auto maps = d_maps.lock(); - maps->d_v4map.fd = bpf_create_map(BPF_MAP_TYPE_HASH, sizeof(uint32_t), sizeof(uint64_t), (int) maxV4Addresses); - if (maps->d_v4map.fd == -1) { + maps->d_v4map = FDWrapper(bpf_create_map(BPF_MAP_TYPE_HASH, sizeof(uint32_t), sizeof(uint64_t), (int) maxV4Addresses)); + if (maps->d_v4map.getHandle() == -1) { throw std::runtime_error("Error creating a BPF v4 map of size " + std::to_string(maxV4Addresses) + ": " + stringerror()); } - maps->d_v6map.fd = bpf_create_map(BPF_MAP_TYPE_HASH, sizeof(struct KeyV6), sizeof(uint64_t), (int) maxV6Addresses); - if (maps->d_v6map.fd == -1) { + maps->d_v6map = FDWrapper(bpf_create_map(BPF_MAP_TYPE_HASH, sizeof(struct KeyV6), sizeof(uint64_t), (int) maxV6Addresses)); + if (maps->d_v6map.getHandle() == -1) { throw std::runtime_error("Error creating a BPF v6 map of size " + std::to_string(maxV6Addresses) + ": " + stringerror()); } - maps->d_qnamemap.fd = bpf_create_map(BPF_MAP_TYPE_HASH, sizeof(struct QNameKey), sizeof(struct QNameValue), (int) maxQNames); - if (maps->d_qnamemap.fd == -1) { + maps->d_qnamemap = FDWrapper(bpf_create_map(BPF_MAP_TYPE_HASH, sizeof(struct QNameKey), sizeof(struct QNameValue), (int) maxQNames)); + if (maps->d_qnamemap.getHandle() == -1) { 
throw std::runtime_error("Error creating a BPF qname map of size " + std::to_string(maxQNames) + ": " + stringerror()); } - maps->d_filtermap.fd = bpf_create_map(BPF_MAP_TYPE_PROG_ARRAY, sizeof(uint32_t), sizeof(uint32_t), 1); - if (maps->d_filtermap.fd == -1) { + maps->d_filtermap = FDWrapper(bpf_create_map(BPF_MAP_TYPE_PROG_ARRAY, sizeof(uint32_t), sizeof(uint32_t), 1)); + if (maps->d_filtermap.getHandle() == -1) { throw std::runtime_error("Error creating a BPF program map of size 1: " + stringerror()); } @@ -166,12 +166,12 @@ BPFFilter::BPFFilter(uint32_t maxV4Addresses, uint32_t maxV6Addresses, uint32_t #include "bpf-filter.main.ebpf" }; - d_mainfilter.fd = bpf_prog_load(BPF_PROG_TYPE_SOCKET_FILTER, - main_filter, - sizeof(main_filter), - "GPL", - 0); - if (d_mainfilter.fd == -1) { + d_mainfilter = FDWrapper(bpf_prog_load(BPF_PROG_TYPE_SOCKET_FILTER, + main_filter, + sizeof(main_filter), + "GPL", + 0)); + if (d_mainfilter.getHandle() == -1) { throw std::runtime_error("Error loading BPF main filter: " + stringerror()); } @@ -179,17 +179,18 @@ BPFFilter::BPFFilter(uint32_t maxV4Addresses, uint32_t maxV6Addresses, uint32_t #include "bpf-filter.qname.ebpf" }; - d_qnamefilter.fd = bpf_prog_load(BPF_PROG_TYPE_SOCKET_FILTER, - qname_filter, - sizeof(qname_filter), - "GPL", - 0); - if (d_qnamefilter.fd == -1) { + d_qnamefilter = FDWrapper(bpf_prog_load(BPF_PROG_TYPE_SOCKET_FILTER, + qname_filter, + sizeof(qname_filter), + "GPL", + 0)); + if (d_qnamefilter.getHandle() == -1) { throw std::runtime_error("Error loading BPF qname filter: " + stringerror()); } uint32_t key = 0; - int res = bpf_update_elem(maps->d_filtermap.fd, &key, &d_qnamefilter.fd, BPF_ANY); + int qnamefd = d_qnamefilter.getHandle(); + int res = bpf_update_elem(maps->d_filtermap.getHandle(), &key, &qnamefd, BPF_ANY); if (res != 0) { throw std::runtime_error("Error updating BPF filters map: " + stringerror()); } @@ -197,7 +198,8 @@ BPFFilter::BPFFilter(uint32_t maxV4Addresses, uint32_t maxV6Addresses, 
uint32_t void BPFFilter::addSocket(int sock) { - int res = setsockopt(sock, SOL_SOCKET, SO_ATTACH_BPF, &d_mainfilter.fd, sizeof(d_mainfilter.fd)); + int fd = d_mainfilter.getHandle(); + int res = setsockopt(sock, SOL_SOCKET, SO_ATTACH_BPF, &fd, sizeof(fd)); if (res != 0) { throw std::runtime_error("Error attaching BPF filter to this socket: " + stringerror()); @@ -206,7 +208,8 @@ void BPFFilter::addSocket(int sock) void BPFFilter::removeSocket(int sock) { - int res = setsockopt(sock, SOL_SOCKET, SO_DETACH_BPF, &d_mainfilter.fd, sizeof(d_mainfilter.fd)); + int fd = d_mainfilter.getHandle(); + int res = setsockopt(sock, SOL_SOCKET, SO_DETACH_BPF, &fd, sizeof(fd)); if (res != 0) { throw std::runtime_error("Error detaching BPF filter from this socket: " + stringerror()); @@ -224,12 +227,12 @@ void BPFFilter::block(const ComboAddress& addr) throw std::runtime_error("Table full when trying to block " + addr.toString()); } - res = bpf_lookup_elem(maps->d_v4map.fd, &key, &counter); + res = bpf_lookup_elem(maps->d_v4map.getHandle(), &key, &counter); if (res != -1) { throw std::runtime_error("Trying to block an already blocked address: " + addr.toString()); } - res = bpf_update_elem(maps->d_v4map.fd, &key, &counter, BPF_NOEXIST); + res = bpf_update_elem(maps->d_v4map.getHandle(), &key, &counter, BPF_NOEXIST); if (res == 0) { maps->d_v4Count++; } @@ -246,12 +249,12 @@ void BPFFilter::block(const ComboAddress& addr) throw std::runtime_error("Table full when trying to block " + addr.toString()); } - res = bpf_lookup_elem(maps->d_v6map.fd, &key, &counter); + res = bpf_lookup_elem(maps->d_v6map.getHandle(), &key, &counter); if (res != -1) { throw std::runtime_error("Trying to block an already blocked address: " + addr.toString()); } - res = bpf_update_elem(maps->d_v6map.fd, key, &counter, BPF_NOEXIST); + res = bpf_update_elem(maps->d_v6map.getHandle(), key, &counter, BPF_NOEXIST); if (res == 0) { maps->d_v6Count++; } @@ -268,7 +271,7 @@ void BPFFilter::unblock(const ComboAddress& 
addr) if (addr.isIPv4()) { uint32_t key = htonl(addr.sin4.sin_addr.s_addr); auto maps = d_maps.lock(); - res = bpf_delete_elem(maps->d_v4map.fd, &key); + res = bpf_delete_elem(maps->d_v4map.getHandle(), &key); if (res == 0) { maps->d_v4Count--; } @@ -281,7 +284,7 @@ void BPFFilter::unblock(const ComboAddress& addr) } auto maps = d_maps.lock(); - res = bpf_delete_elem(maps->d_v6map.fd, key); + res = bpf_delete_elem(maps->d_v6map.getHandle(), key); if (res == 0) { maps->d_v6Count--; } @@ -313,12 +316,12 @@ void BPFFilter::block(const DNSName& qname, uint16_t qtype) throw std::runtime_error("Table full when trying to block " + qname.toLogString()); } - int res = bpf_lookup_elem(maps->d_qnamemap.fd, &key, &value); + int res = bpf_lookup_elem(maps->d_qnamemap.getHandle(), &key, &value); if (res != -1) { throw std::runtime_error("Trying to block an already blocked qname: " + qname.toLogString()); } - res = bpf_update_elem(maps->d_qnamemap.fd, &key, &value, BPF_NOEXIST); + res = bpf_update_elem(maps->d_qnamemap.getHandle(), &key, &value, BPF_NOEXIST); if (res == 0) { maps->d_qNamesCount++; } @@ -342,7 +345,7 @@ void BPFFilter::unblock(const DNSName& qname, uint16_t qtype) { auto maps = d_maps.lock(); - int res = bpf_delete_elem(maps->d_qnamemap.fd, &key); + int res = bpf_delete_elem(maps->d_qnamemap.getHandle(), &key); if (res == 0) { maps->d_qNamesCount--; } @@ -378,28 +381,28 @@ std::vector > BPFFilter::getAddrStats() memset(&v6Key, 0, sizeof(v6Key)); auto maps = d_maps.lock(); - int res = bpf_get_next_key(maps->d_v4map.fd, &v4Key, &nextV4Key); + int res = bpf_get_next_key(maps->d_v4map.getHandle(), &v4Key, &nextV4Key); while (res == 0) { v4Key = nextV4Key; - if (bpf_lookup_elem(maps->d_v4map.fd, &v4Key, &value) == 0) { + if (bpf_lookup_elem(maps->d_v4map.getHandle(), &v4Key, &value) == 0) { v4Addr.sin_addr.s_addr = ntohl(v4Key); result.push_back(make_pair(ComboAddress(&v4Addr), value)); } - res = bpf_get_next_key(maps->d_v4map.fd, &v4Key, &nextV4Key); + res = 
bpf_get_next_key(maps->d_v4map.getHandle(), &v4Key, &nextV4Key); } - res = bpf_get_next_key(maps->d_v6map.fd, &v6Key, &nextV6Key); + res = bpf_get_next_key(maps->d_v6map.getHandle(), &v6Key, &nextV6Key); while (res == 0) { - if (bpf_lookup_elem(maps->d_v6map.fd, &nextV6Key, &value) == 0) { + if (bpf_lookup_elem(maps->d_v6map.getHandle(), &nextV6Key, &value) == 0) { memcpy(&v6Addr.sin6_addr.s6_addr, &nextV6Key, sizeof(nextV6Key)); result.push_back(make_pair(ComboAddress(&v6Addr), value)); } - res = bpf_get_next_key(maps->d_v6map.fd, &nextV6Key, &nextV6Key); + res = bpf_get_next_key(maps->d_v6map.getHandle(), &nextV6Key, &nextV6Key); } return result; } @@ -414,15 +417,15 @@ std::vector > BPFFilter::getQNameStats() auto maps = d_maps.lock(); result.reserve(maps->d_qNamesCount); - int res = bpf_get_next_key(maps->d_qnamemap.fd, &key, &nextKey); + int res = bpf_get_next_key(maps->d_qnamemap.getHandle(), &key, &nextKey); while (res == 0) { - if (bpf_lookup_elem(maps->d_qnamemap.fd, &nextKey, &value) == 0) { + if (bpf_lookup_elem(maps->d_qnamemap.getHandle(), &nextKey, &value) == 0) { nextKey.qname[sizeof(nextKey.qname) - 1 ] = '\0'; result.push_back(std::make_tuple(DNSName((const char*) nextKey.qname, sizeof(nextKey.qname), 0, false), value.qtype, value.counter)); } - res = bpf_get_next_key(maps->d_qnamemap.fd, &nextKey, &nextKey); + res = bpf_get_next_key(maps->d_qnamemap.getHandle(), &nextKey, &nextKey); } return result; } @@ -434,7 +437,7 @@ uint64_t BPFFilter::getHits(const ComboAddress& requestor) uint32_t key = htonl(requestor.sin4.sin_addr.s_addr); auto maps = d_maps.lock(); - int res = bpf_lookup_elem(maps->d_v4map.fd, &key, &counter); + int res = bpf_lookup_elem(maps->d_v4map.getHandle(), &key, &counter); if (res == 0) { return counter; } @@ -447,7 +450,7 @@ uint64_t BPFFilter::getHits(const ComboAddress& requestor) } auto maps = d_maps.lock(); - int res = bpf_lookup_elem(maps->d_v6map.fd, &key, &counter); + int res = bpf_lookup_elem(maps->d_v6map.getHandle(), 
&key, &counter); if (res == 0) { return counter; } diff --git a/pdns/bpf-filter.hh b/pdns/bpf-filter.hh index abf805e4c2..482b61e31a 100644 --- a/pdns/bpf-filter.hh +++ b/pdns/bpf-filter.hh @@ -41,16 +41,6 @@ public: private: #ifdef HAVE_EBPF - struct FDWrapper - { - ~FDWrapper() - { - if (fd != -1) { - close(fd); - } - } - int fd{-1}; - }; struct Maps { FDWrapper d_v4map; diff --git a/pdns/bpf-filter.main.ebpf b/pdns/bpf-filter.main.ebpf index a8ffa37ad1..8a82f3bce4 100644 --- a/pdns/bpf-filter.main.ebpf +++ b/pdns/bpf-filter.main.ebpf @@ -6,7 +6,7 @@ BPF_JMP_IMM(BPF_JEQ,BPF_REG_1,ntohs(0x86dd),11), BPF_JMP_IMM(BPF_JNE,BPF_REG_1,ntohs(0x0800),109), BPF_LD_ABS(BPF_W,-2097126), BPF_STX_MEM(BPF_W,BPF_REG_10,BPF_REG_0,-256), -BPF_LD_MAP_FD(BPF_REG_1,maps->d_v4map.fd), +BPF_LD_MAP_FD(BPF_REG_1,maps->d_v4map.getHandle()), BPF_MOV64_REG(BPF_REG_2,BPF_REG_10), BPF_ALU64_IMM(BPF_ADD,BPF_REG_2,-256), BPF_RAW_INSN(BPF_JMP|BPF_CALL,0,0,0,BPF_FUNC_map_lookup_elem), @@ -45,7 +45,7 @@ BPF_LD_ABS(BPF_B,-2097116), BPF_STX_MEM(BPF_B,BPF_REG_10,BPF_REG_0,-242), BPF_LD_ABS(BPF_B,-2097115), BPF_STX_MEM(BPF_B,BPF_REG_10,BPF_REG_0,-241), -BPF_LD_MAP_FD(BPF_REG_1,maps->d_v6map.fd), +BPF_LD_MAP_FD(BPF_REG_1,maps->d_v6map.getHandle()), BPF_MOV64_REG(BPF_REG_2,BPF_REG_10), BPF_ALU64_IMM(BPF_ADD,BPF_REG_2,-256), BPF_RAW_INSN(BPF_JMP|BPF_CALL,0,0,0,BPF_FUNC_map_lookup_elem), @@ -97,7 +97,7 @@ BPF_JMP_IMM(BPF_JGT,BPF_REG_8,63,17), BPF_JMP_IMM(BPF_JNE,BPF_REG_8,0,18), BPF_LD_ABS(BPF_H,21), BPF_MOV64_REG(BPF_REG_6,BPF_REG_0), -BPF_LD_MAP_FD(BPF_REG_1,maps->d_qnamemap.fd), +BPF_LD_MAP_FD(BPF_REG_1,maps->d_qnamemap.getHandle()), BPF_MOV64_REG(BPF_REG_2,BPF_REG_10), BPF_ALU64_IMM(BPF_ADD,BPF_REG_2,-256), BPF_RAW_INSN(BPF_JMP|BPF_CALL,0,0,0,BPF_FUNC_map_lookup_elem), @@ -128,7 +128,7 @@ BPF_ALU64_IMM(BPF_ADD,BPF_REG_8,-1), BPF_STX_MEM(BPF_W,BPF_REG_6,BPF_REG_8,60), BPF_ALU64_IMM(BPF_AND,BPF_REG_1,255), BPF_STX_MEM(BPF_W,BPF_REG_6,BPF_REG_1,56), -BPF_LD_MAP_FD(BPF_REG_2,maps->d_filtermap.fd), 
+BPF_LD_MAP_FD(BPF_REG_2,maps->d_filtermap.getHandle()), BPF_MOV64_REG(BPF_REG_1,BPF_REG_6), BPF_MOV64_IMM(BPF_REG_3,0), BPF_RAW_INSN(BPF_JMP|BPF_CALL,0,0,0,BPF_FUNC_tail_call), diff --git a/pdns/bpf-filter.qname.ebpf b/pdns/bpf-filter.qname.ebpf index febb7b862c..7eb0519bd1 100644 --- a/pdns/bpf-filter.qname.ebpf +++ b/pdns/bpf-filter.qname.ebpf @@ -4078,7 +4078,7 @@ BPF_MOV64_IMM(BPF_REG_9,255), BPF_ALU64_REG(BPF_ADD,BPF_REG_9,BPF_REG_7), BPF_RAW_INSN(BPF_LD|BPF_IND|BPF_H,BPF_REG_0,BPF_REG_9,0,0), BPF_MOV64_REG(BPF_REG_6,BPF_REG_0), -BPF_LD_MAP_FD(BPF_REG_1,maps->d_qnamemap.fd), +BPF_LD_MAP_FD(BPF_REG_1,maps->d_qnamemap.getHandle()), BPF_MOV64_REG(BPF_REG_2,BPF_REG_10), BPF_ALU64_IMM(BPF_ADD,BPF_REG_2,-256), BPF_RAW_INSN(BPF_JMP|BPF_CALL,0,0,0,BPF_FUNC_map_lookup_elem), diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index 1b478530c9..71dac2f47c 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -124,10 +124,13 @@ std::shared_ptr IncomingTCPConnectionState::getDownstrea return downstream; } -static void tcpClientThread(int pipefd, int crossProtocolPipeFD); +static void tcpClientThread(int pipefd, int crossProtocolQueriesPipeFD, int crossProtocolResponsesListenPipeFD, int crossProtocolResponsesWritePipeFD); TCPClientCollection::TCPClientCollection(size_t maxThreads): d_tcpclientthreads(maxThreads), d_maxthreads(maxThreads) { + for (size_t idx = 0; idx < maxThreads; idx++) { + addTCPClientThread(); + } } void TCPClientCollection::addTCPClientThread() @@ -166,20 +169,25 @@ void TCPClientCollection::addTCPClientThread() return; } - int crossProtocolFDs[2] = { -1, -1}; - if (!preparePipe(crossProtocolFDs, "cross-protocol")) { + int crossProtocolQueriesFDs[2] = { -1, -1}; + if (!preparePipe(crossProtocolQueriesFDs, "cross-protocol queries")) { + return; + } + + int crossProtocolResponsesFDs[2] = { -1, -1}; + if (!preparePipe(crossProtocolResponsesFDs, "cross-protocol responses")) { return; } vinfolog("Adding TCP Client thread"); { - std::lock_guard 
lock(d_mutex); - if (d_numthreads >= d_tcpclientthreads.size()) { vinfolog("Adding a new TCP client thread would exceed the vector size (%d/%d), skipping. Consider increasing the maximum amount of TCP client threads with setMaxTCPClientThreads() in the configuration.", d_numthreads.load(), d_tcpclientthreads.size()); - close(crossProtocolFDs[0]); - close(crossProtocolFDs[1]); + close(crossProtocolQueriesFDs[0]); + close(crossProtocolQueriesFDs[1]); + close(crossProtocolResponsesFDs[0]); + close(crossProtocolResponsesFDs[1]); close(pipefds[0]); close(pipefds[1]); return; @@ -187,15 +195,17 @@ void TCPClientCollection::addTCPClientThread() /* from now on this side of the pipe will be managed by that object, no need to worry about it */ - TCPWorkerThread worker(pipefds[1], crossProtocolFDs[1]); + TCPWorkerThread worker(pipefds[1], crossProtocolQueriesFDs[1], crossProtocolResponsesFDs[1]); try { - std::thread t1(tcpClientThread, pipefds[0], crossProtocolFDs[0]); + std::thread t1(tcpClientThread, pipefds[0], crossProtocolQueriesFDs[0], crossProtocolResponsesFDs[0], crossProtocolResponsesFDs[1]); t1.detach(); } catch (const std::runtime_error& e) { /* the thread creation failed, don't leak */ errlog("Error creating a TCP thread: %s", e.what()); close(pipefds[0]); + close(crossProtocolQueriesFDs[0]); + close(crossProtocolResponsesFDs[0]); return; } @@ -471,10 +481,74 @@ void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPRe queueResponse(state, now, std::move(response)); } +struct TCPCrossProtocolResponse +{ + TCPCrossProtocolResponse(TCPResponse&& response, std::shared_ptr& state, const struct timeval& now): d_response(std::move(response)), d_state(state), d_now(now) + { + } + + TCPResponse d_response; + std::shared_ptr d_state; + struct timeval d_now; +}; + +class TCPCrossProtocolQuerySender : public TCPQuerySender +{ +public: + TCPCrossProtocolQuerySender(std::shared_ptr& state, int responseDescriptor): d_state(state), 
d_responseDesc(responseDescriptor) + { + } + + bool active() const override + { + return d_state->active(); + } + + const ClientState* getClientState() const override + { + return d_state->getClientState(); + } + + void handleResponse(const struct timeval& now, TCPResponse&& response) override + { + if (d_responseDesc == -1) { + throw std::runtime_error("Invalid pipe descriptor in TCP Cross Protocol Query Sender"); + } + + auto ptr = new TCPCrossProtocolResponse(std::move(response), d_state, now); + static_assert(sizeof(ptr) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail"); + ssize_t sent = write(d_responseDesc, &ptr, sizeof(ptr)); + if (sent != sizeof(ptr)) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because the pipe is full"); + } + else { + vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror()); + } + delete ptr; + } + } + + void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override + { + handleResponse(now, std::move(response)); + } + + void notifyIOError(IDState&& query, const struct timeval& now) override + { + TCPResponse response(PacketBuffer(), std::move(query), nullptr); + handleResponse(now, std::move(response)); + } + +private: + std::shared_ptr d_state; + int d_responseDesc{-1}; +}; + class TCPCrossProtocolQuery : public CrossProtocolQuery { public: - TCPCrossProtocolQuery(PacketBuffer&& buffer, IDState&& ids, std::shared_ptr& ds, std::shared_ptr& sender): d_sender(sender) + TCPCrossProtocolQuery(PacketBuffer&& buffer, IDState&& ids, std::shared_ptr& ds, std::shared_ptr& sender): d_sender(sender) { query = InternalQuery(std::move(buffer), std::move(ids)); downstream = ds; @@ -492,7 +566,7 @@ public: } private: - std::shared_ptr d_sender; + std::shared_ptr d_sender; }; static void 
handleQuery(std::shared_ptr& state, const struct timeval& now) @@ -621,8 +695,8 @@ static void handleQuery(std::shared_ptr& state, cons ++state->d_currentQueriesCount; if (ds->isDoH()) { - std::shared_ptr incoming = state; - auto cpq = std::make_unique(std::move(state->d_buffer), std::move(ids), ds, state); + auto incoming = std::make_shared(state, state->d_threadData.crossProtocolResponsesPipe); + auto cpq = std::make_unique(std::move(state->d_buffer), std::move(ids), ds, incoming); ds->passCrossProtocolQuery(std::move(cpq)); return; @@ -630,7 +704,6 @@ static void handleQuery(std::shared_ptr& state, cons prependSizeToTCPQuery(state->d_buffer, 0); -#warning FIXME: handle DoH backends here auto downstreamConnection = state->getDownstreamConnection(ds, dq.proxyProtocolValues, now); bool proxyProtocolPayloadAdded = false; @@ -1080,7 +1153,40 @@ static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& par } } -static void tcpClientThread(int pipefd, int crossProtocolPipeFD) +static void handleCrossProtocolResponse(int pipefd, FDMultiplexer::funcparam_t& param) +{ + TCPCrossProtocolResponse* tmp{nullptr}; + + ssize_t got = read(pipefd, &tmp, sizeof(tmp)); + if (got == 0) { + throw std::runtime_error("EOF while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode"); + } + else if (got == -1) { + if (errno == EAGAIN || errno == EINTR) { + return; + } + throw std::runtime_error("Error while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror()); + } + else if (got != sizeof(tmp)) { + throw std::runtime_error("Partial read while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? 
"non-blocking" : "blocking") + " mode"); + } + + auto response = std::move(*tmp); + delete tmp; + tmp = nullptr; + + if (response.d_response.d_buffer.empty()) { + response.d_state->notifyIOError(std::move(response.d_response.d_idstate), response.d_now); + } + else if (response.d_response.d_idstate.qtype == QType::AXFR || response.d_response.d_idstate.qtype == QType::IXFR) { + response.d_state->handleXFRResponse(response.d_now, std::move(response.d_response)); + } + else { + response.d_state->handleXFRResponse(response.d_now, std::move(response.d_response)); + } +} + +static void tcpClientThread(int pipefd, int crossProtocolQueriesPipeFD, int crossProtocolResponsesListenPipeFD, int crossProtocolResponsesWritePipeFD) { /* we get launched with a pipe on which we receive file descriptors from clients that we own from that point on */ @@ -1088,9 +1194,11 @@ static void tcpClientThread(int pipefd, int crossProtocolPipeFD) setThreadName("dnsdist/tcpClie"); TCPClientThreadData data; - + /* this is the writing end! 
*/ + data.crossProtocolResponsesPipe = crossProtocolResponsesWritePipeFD; data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data); - data.mplexer->addReadFD(crossProtocolPipeFD, handleCrossProtocolQuery, &data); + data.mplexer->addReadFD(crossProtocolQueriesPipeFD, handleCrossProtocolQuery, &data); + data.mplexer->addReadFD(crossProtocolResponsesListenPipeFD, handleCrossProtocolResponse, &data); struct timeval now; gettimeofday(&now, nullptr); diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 57d5322b04..dbcbbcc4c8 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -1312,7 +1312,7 @@ public: return true; } - const ClientState* getClientState() override + const ClientState* getClientState() const override { return &d_cs; } @@ -2549,6 +2549,9 @@ int main(int argc, char** argv) g_maxTCPClientThreads = 1; } + /* we need to create the TCP worker threads before the + acceptor ones, otherwise we might crash when processing + the first TCP query */ g_tcpclientthreads = std::make_unique(*g_maxTCPClientThreads); initDoHWorkers(); @@ -2590,13 +2593,6 @@ int main(int argc, char** argv) } handleQueuedHealthChecks(*mplexer, true); - /* we need to create the TCP worker threads before the - acceptor ones, otherwise we might crash when processing - the first TCP query */ - while (!g_tcpclientthreads->hasReachedMaxThreads()) { - g_tcpclientthreads->addTCPClientThread(); - } - for(auto& cs : g_frontends) { if (cs->dohFrontend != nullptr) { #ifdef HAVE_DNS_OVER_HTTPS diff --git a/pdns/dnsdistdist/dnsdist-healthchecks.cc b/pdns/dnsdistdist/dnsdist-healthchecks.cc index b8c280ac63..7ae87d9c19 100644 --- a/pdns/dnsdistdist/dnsdist-healthchecks.cc +++ b/pdns/dnsdistdist/dnsdist-healthchecks.cc @@ -196,7 +196,7 @@ public: return true; } - const ClientState* getClientState() override + const ClientState* getClientState() const override { return nullptr; } @@ -398,7 +398,7 @@ bool queueHealthCheck(std::unique_ptr& mplexer, const std::shared else if (ds->isDoH()) { InternalQuery 
query(std::move(packet), IDState()); auto sender = std::shared_ptr(new HealthCheckQuerySender(data)); - if (!sendH2Query(ds, mplexer, sender, std::move(query))) { + if (!sendH2Query(ds, mplexer, sender, std::move(query), true)) { updateHealthCheckResult(data->d_ds, data->d_initial, false); } } diff --git a/pdns/dnsdistdist/dnsdist-nghttp2.cc b/pdns/dnsdistdist/dnsdist-nghttp2.cc index 14ceccd8f8..2b503b95df 100644 --- a/pdns/dnsdistdist/dnsdist-nghttp2.cc +++ b/pdns/dnsdistdist/dnsdist-nghttp2.cc @@ -1128,14 +1128,14 @@ bool setupDoHClientProtocolNegotiation(std::shared_ptr& ctx) #endif /* HAVE_NGHTTP2 */ } -bool sendH2Query(const std::shared_ptr& ds, std::unique_ptr& mplexer, std::shared_ptr& sender, InternalQuery&& query) +bool sendH2Query(const std::shared_ptr& ds, std::unique_ptr& mplexer, std::shared_ptr& sender, InternalQuery&& query, bool healthCheck) { #ifdef HAVE_NGHTTP2 struct timeval now; gettimeofday(&now, nullptr); auto newConnection = std::make_shared(ds, mplexer, now); - newConnection->setHealthCheck(true); + newConnection->setHealthCheck(healthCheck); newConnection->queueQuery(sender, std::move(query)); return true; #else /* HAVE_NGHTTP2 */ diff --git a/pdns/dnsdistdist/dnsdist-nghttp2.hh b/pdns/dnsdistdist/dnsdist-nghttp2.hh index 1f3aa23bdd..5949cd3cdd 100644 --- a/pdns/dnsdistdist/dnsdist-nghttp2.hh +++ b/pdns/dnsdistdist/dnsdist-nghttp2.hh @@ -68,4 +68,4 @@ bool setupDoHClientProtocolNegotiation(std::shared_ptr& ctx); /* opens a new HTTP/2 connection to the supplied backend (attached to the supplied multiplexer), sends the query, waits for the response to come back or an error to occur then notifies the sender, closing the connection. 
*/ -bool sendH2Query(const std::shared_ptr& ds, std::unique_ptr& mplexer, std::shared_ptr& sender, InternalQuery&& query); +bool sendH2Query(const std::shared_ptr& ds, std::unique_ptr& mplexer, std::shared_ptr& sender, InternalQuery&& query, bool healthCheck); diff --git a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh index 498fea31e7..e42c8acf74 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh +++ b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh @@ -13,6 +13,7 @@ public: LocalHolders holders; LocalStateHolder > localRespRuleActions; std::unique_ptr mplexer{nullptr}; + int crossProtocolResponsesPipe{-1}; }; class IncomingTCPConnectionState : public TCPQuerySender, public std::enable_shared_from_this @@ -132,7 +133,7 @@ static void handleTimeout(std::shared_ptr& state, bo return d_ioState != nullptr; } - const ClientState* getClientState() override + const ClientState* getClientState() const override { return d_ci.cs; } diff --git a/pdns/dnsdistdist/dnsdist-tcp.hh b/pdns/dnsdistdist/dnsdist-tcp.hh index d891f31cdd..312d0b72a4 100644 --- a/pdns/dnsdistdist/dnsdist-tcp.hh +++ b/pdns/dnsdistdist/dnsdist-tcp.hh @@ -144,7 +144,7 @@ public: } virtual bool active() const = 0; - virtual const ClientState* getClientState() = 0; + virtual const ClientState* getClientState() const = 0; virtual void handleResponse(const struct timeval& now, TCPResponse&& response) = 0; virtual void handleXFRResponse(const struct timeval& now, TCPResponse&& response) = 0; virtual void notifyIOError(IDState&& query, const struct timeval& now) = 0; @@ -192,7 +192,7 @@ public: uint64_t pos = d_pos++; ++d_queued; - return d_tcpclientthreads.at(pos % d_numthreads).d_newConnectionPipe; + return d_tcpclientthreads.at(pos % d_numthreads).d_newConnectionPipe.getHandle(); } bool passConnectionToThread(std::unique_ptr&& conn) @@ -202,7 +202,7 @@ public: } uint64_t pos = d_pos++; - auto pipe = d_tcpclientthreads.at(pos % d_numthreads).d_newConnectionPipe; + auto pipe = 
d_tcpclientthreads.at(pos % d_numthreads).d_newConnectionPipe.getHandle(); auto tmp = conn.release(); if (write(pipe, &tmp, sizeof(tmp)) != sizeof(tmp)) { @@ -221,7 +221,7 @@ public: } uint64_t pos = d_pos++; - auto pipe = d_tcpclientthreads.at(pos % d_numthreads).d_crossProtocolQueryPipe; + auto pipe = d_tcpclientthreads.at(pos % d_numthreads).d_crossProtocolQueriesPipe.getHandle(); auto tmp = cpq.release(); if (write(pipe, &tmp, sizeof(tmp)) != sizeof(tmp)) { @@ -253,62 +253,30 @@ public: --d_queued; } +private: void addTCPClientThread(); -private: struct TCPWorkerThread { TCPWorkerThread() { } - TCPWorkerThread(int newConnPipe, int crossProtocolPipe) : - d_newConnectionPipe(newConnPipe), d_crossProtocolQueryPipe(crossProtocolPipe) - { - } - - TCPWorkerThread(TCPWorkerThread&& rhs) : - d_newConnectionPipe(rhs.d_newConnectionPipe), d_crossProtocolQueryPipe(rhs.d_crossProtocolQueryPipe) + TCPWorkerThread(int newConnPipe, int crossProtocolQueriesPipe, int crossProtocolResponsesPipe) : + d_newConnectionPipe(newConnPipe), d_crossProtocolQueriesPipe(crossProtocolQueriesPipe), d_crossProtocolResponsesPipe(crossProtocolResponsesPipe) { - rhs.d_newConnectionPipe = -1; - rhs.d_crossProtocolQueryPipe = -1; - } - - TCPWorkerThread& operator=(TCPWorkerThread&& rhs) - { - if (d_newConnectionPipe != -1) { - close(d_newConnectionPipe); - } - if (d_crossProtocolQueryPipe != -1) { - close(d_crossProtocolQueryPipe); - } - - d_newConnectionPipe = rhs.d_newConnectionPipe; - d_crossProtocolQueryPipe = rhs.d_crossProtocolQueryPipe; - rhs.d_newConnectionPipe = -1; - rhs.d_crossProtocolQueryPipe = -1; - - return *this; } + TCPWorkerThread(TCPWorkerThread&& rhs) = default; + TCPWorkerThread& operator=(TCPWorkerThread&& rhs) = default; TCPWorkerThread(const TCPWorkerThread& rhs) = delete; TCPWorkerThread& operator=(const TCPWorkerThread&) = delete; - ~TCPWorkerThread() - { - if (d_newConnectionPipe != -1) { - close(d_newConnectionPipe); - } - if (d_crossProtocolQueryPipe != -1) { - 
close(d_crossProtocolQueryPipe); - } - } - - int d_newConnectionPipe{-1}; - int d_crossProtocolQueryPipe{-1}; + FDWrapper d_newConnectionPipe; + FDWrapper d_crossProtocolQueriesPipe; + FDWrapper d_crossProtocolResponsesPipe; }; - std::mutex d_mutex; std::vector d_tcpclientthreads; stat_t d_numthreads{0}; stat_t d_pos{0}; diff --git a/pdns/dnsdistdist/doh.cc b/pdns/dnsdistdist/doh.cc index ef08f48e68..fad20531a9 100644 --- a/pdns/dnsdistdist/doh.cc +++ b/pdns/dnsdistdist/doh.cc @@ -424,7 +424,7 @@ public: return true; } - const ClientState* getClientState() override + const ClientState* getClientState() const override { if (!du || !du->dsc || !du->dsc->cs) { throw std::runtime_error("No query associated to this DoHTCPCrossQuerySender"); diff --git a/pdns/misc.hh b/pdns/misc.hh index c4f5457577..5eefe8054a 100644 --- a/pdns/misc.hh +++ b/pdns/misc.hh @@ -635,3 +635,45 @@ std::string makeLuaString(const std::string& in); // Used in NID and L64 records struct NodeOrLocatorID { uint8_t content[8]; }; + +struct FDWrapper +{ + FDWrapper() + { + } + + FDWrapper(int desc): d_fd(desc) + { + } + + ~FDWrapper() + { + if (d_fd != -1) { + close(d_fd); + d_fd = -1; + } + } + + FDWrapper(FDWrapper&& rhs): d_fd(rhs.d_fd) + { + rhs.d_fd = -1; + } + + FDWrapper& operator=(FDWrapper&& rhs) + { + if (d_fd != -1) { + close(d_fd); + } + d_fd = rhs.d_fd; + rhs.d_fd = -1; + return *this; + } + + int getHandle() const + { + return d_fd; + } + +private: + int d_fd{-1}; +}; diff --git a/pdns/sstuff.hh b/pdns/sstuff.hh index f84d50c364..50a669bf86 100644 --- a/pdns/sstuff.hh +++ b/pdns/sstuff.hh @@ -65,6 +65,9 @@ public: Socket& operator=(Socket&& rhs) { + if (d_socket != -1) { + close(d_socket); + } d_socket = rhs.d_socket; rhs.d_socket = -1; d_buffer = std::move(rhs.d_buffer); diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index 4647b92794..083d72ee93 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ 
b/regression-tests.dnsdist/dnsdisttests.py @@ -263,6 +263,9 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): (conn, _) = sock.accept() except ssl.SSLError: continue + except ConnectionResetError: + continue + conn.settimeout(5.0) data = conn.recv(2) if not data: @@ -348,6 +351,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): continue except ConnectionResetError: continue + conn.settimeout(5.0) h2conn = h2.connection.H2Connection(config=config) h2conn.initiate_connection()