BPFFilter::BPFFilter(uint32_t maxV4Addresses, uint32_t maxV6Addresses, uint32_t maxQNames): d_maps(Maps()), d_maxV4(maxV4Addresses), d_maxV6(maxV6Addresses), d_maxQNames(maxQNames)
{
auto maps = d_maps.lock();
- maps->d_v4map.fd = bpf_create_map(BPF_MAP_TYPE_HASH, sizeof(uint32_t), sizeof(uint64_t), (int) maxV4Addresses);
- if (maps->d_v4map.fd == -1) {
+ maps->d_v4map = FDWrapper(bpf_create_map(BPF_MAP_TYPE_HASH, sizeof(uint32_t), sizeof(uint64_t), (int) maxV4Addresses));
+ if (maps->d_v4map.getHandle() == -1) {
throw std::runtime_error("Error creating a BPF v4 map of size " + std::to_string(maxV4Addresses) + ": " + stringerror());
}
- maps->d_v6map.fd = bpf_create_map(BPF_MAP_TYPE_HASH, sizeof(struct KeyV6), sizeof(uint64_t), (int) maxV6Addresses);
- if (maps->d_v6map.fd == -1) {
+ maps->d_v6map = FDWrapper(bpf_create_map(BPF_MAP_TYPE_HASH, sizeof(struct KeyV6), sizeof(uint64_t), (int) maxV6Addresses));
+ if (maps->d_v6map.getHandle() == -1) {
throw std::runtime_error("Error creating a BPF v6 map of size " + std::to_string(maxV6Addresses) + ": " + stringerror());
}
- maps->d_qnamemap.fd = bpf_create_map(BPF_MAP_TYPE_HASH, sizeof(struct QNameKey), sizeof(struct QNameValue), (int) maxQNames);
- if (maps->d_qnamemap.fd == -1) {
+ maps->d_qnamemap = FDWrapper(bpf_create_map(BPF_MAP_TYPE_HASH, sizeof(struct QNameKey), sizeof(struct QNameValue), (int) maxQNames));
+ if (maps->d_qnamemap.getHandle() == -1) {
throw std::runtime_error("Error creating a BPF qname map of size " + std::to_string(maxQNames) + ": " + stringerror());
}
- maps->d_filtermap.fd = bpf_create_map(BPF_MAP_TYPE_PROG_ARRAY, sizeof(uint32_t), sizeof(uint32_t), 1);
- if (maps->d_filtermap.fd == -1) {
+ maps->d_filtermap = FDWrapper(bpf_create_map(BPF_MAP_TYPE_PROG_ARRAY, sizeof(uint32_t), sizeof(uint32_t), 1));
+ if (maps->d_filtermap.getHandle() == -1) {
throw std::runtime_error("Error creating a BPF program map of size 1: " + stringerror());
}
#include "bpf-filter.main.ebpf"
};
- d_mainfilter.fd = bpf_prog_load(BPF_PROG_TYPE_SOCKET_FILTER,
- main_filter,
- sizeof(main_filter),
- "GPL",
- 0);
- if (d_mainfilter.fd == -1) {
+ d_mainfilter = FDWrapper(bpf_prog_load(BPF_PROG_TYPE_SOCKET_FILTER,
+ main_filter,
+ sizeof(main_filter),
+ "GPL",
+ 0));
+ if (d_mainfilter.getHandle() == -1) {
throw std::runtime_error("Error loading BPF main filter: " + stringerror());
}
#include "bpf-filter.qname.ebpf"
};
- d_qnamefilter.fd = bpf_prog_load(BPF_PROG_TYPE_SOCKET_FILTER,
- qname_filter,
- sizeof(qname_filter),
- "GPL",
- 0);
- if (d_qnamefilter.fd == -1) {
+ d_qnamefilter = FDWrapper(bpf_prog_load(BPF_PROG_TYPE_SOCKET_FILTER,
+ qname_filter,
+ sizeof(qname_filter),
+ "GPL",
+ 0));
+ if (d_qnamefilter.getHandle() == -1) {
throw std::runtime_error("Error loading BPF qname filter: " + stringerror());
}
uint32_t key = 0;
- int res = bpf_update_elem(maps->d_filtermap.fd, &key, &d_qnamefilter.fd, BPF_ANY);
+ int qnamefd = d_qnamefilter.getHandle();
+ int res = bpf_update_elem(maps->d_filtermap.getHandle(), &key, &qnamefd, BPF_ANY);
if (res != 0) {
throw std::runtime_error("Error updating BPF filters map: " + stringerror());
}
void BPFFilter::addSocket(int sock)
{
- int res = setsockopt(sock, SOL_SOCKET, SO_ATTACH_BPF, &d_mainfilter.fd, sizeof(d_mainfilter.fd));
+ int fd = d_mainfilter.getHandle();
+ int res = setsockopt(sock, SOL_SOCKET, SO_ATTACH_BPF, &fd, sizeof(fd));
if (res != 0) {
throw std::runtime_error("Error attaching BPF filter to this socket: " + stringerror());
void BPFFilter::removeSocket(int sock)
{
- int res = setsockopt(sock, SOL_SOCKET, SO_DETACH_BPF, &d_mainfilter.fd, sizeof(d_mainfilter.fd));
+ int fd = d_mainfilter.getHandle();
+ int res = setsockopt(sock, SOL_SOCKET, SO_DETACH_BPF, &fd, sizeof(fd));
if (res != 0) {
throw std::runtime_error("Error detaching BPF filter from this socket: " + stringerror());
throw std::runtime_error("Table full when trying to block " + addr.toString());
}
- res = bpf_lookup_elem(maps->d_v4map.fd, &key, &counter);
+ res = bpf_lookup_elem(maps->d_v4map.getHandle(), &key, &counter);
if (res != -1) {
throw std::runtime_error("Trying to block an already blocked address: " + addr.toString());
}
- res = bpf_update_elem(maps->d_v4map.fd, &key, &counter, BPF_NOEXIST);
+ res = bpf_update_elem(maps->d_v4map.getHandle(), &key, &counter, BPF_NOEXIST);
if (res == 0) {
maps->d_v4Count++;
}
throw std::runtime_error("Table full when trying to block " + addr.toString());
}
- res = bpf_lookup_elem(maps->d_v6map.fd, &key, &counter);
+ res = bpf_lookup_elem(maps->d_v6map.getHandle(), &key, &counter);
if (res != -1) {
throw std::runtime_error("Trying to block an already blocked address: " + addr.toString());
}
- res = bpf_update_elem(maps->d_v6map.fd, key, &counter, BPF_NOEXIST);
+ res = bpf_update_elem(maps->d_v6map.getHandle(), key, &counter, BPF_NOEXIST);
if (res == 0) {
maps->d_v6Count++;
}
if (addr.isIPv4()) {
uint32_t key = htonl(addr.sin4.sin_addr.s_addr);
auto maps = d_maps.lock();
- res = bpf_delete_elem(maps->d_v4map.fd, &key);
+ res = bpf_delete_elem(maps->d_v4map.getHandle(), &key);
if (res == 0) {
maps->d_v4Count--;
}
}
auto maps = d_maps.lock();
- res = bpf_delete_elem(maps->d_v6map.fd, key);
+ res = bpf_delete_elem(maps->d_v6map.getHandle(), key);
if (res == 0) {
maps->d_v6Count--;
}
throw std::runtime_error("Table full when trying to block " + qname.toLogString());
}
- int res = bpf_lookup_elem(maps->d_qnamemap.fd, &key, &value);
+ int res = bpf_lookup_elem(maps->d_qnamemap.getHandle(), &key, &value);
if (res != -1) {
throw std::runtime_error("Trying to block an already blocked qname: " + qname.toLogString());
}
- res = bpf_update_elem(maps->d_qnamemap.fd, &key, &value, BPF_NOEXIST);
+ res = bpf_update_elem(maps->d_qnamemap.getHandle(), &key, &value, BPF_NOEXIST);
if (res == 0) {
maps->d_qNamesCount++;
}
{
auto maps = d_maps.lock();
- int res = bpf_delete_elem(maps->d_qnamemap.fd, &key);
+ int res = bpf_delete_elem(maps->d_qnamemap.getHandle(), &key);
if (res == 0) {
maps->d_qNamesCount--;
}
memset(&v6Key, 0, sizeof(v6Key));
auto maps = d_maps.lock();
- int res = bpf_get_next_key(maps->d_v4map.fd, &v4Key, &nextV4Key);
+ int res = bpf_get_next_key(maps->d_v4map.getHandle(), &v4Key, &nextV4Key);
while (res == 0) {
v4Key = nextV4Key;
- if (bpf_lookup_elem(maps->d_v4map.fd, &v4Key, &value) == 0) {
+ if (bpf_lookup_elem(maps->d_v4map.getHandle(), &v4Key, &value) == 0) {
v4Addr.sin_addr.s_addr = ntohl(v4Key);
result.push_back(make_pair(ComboAddress(&v4Addr), value));
}
- res = bpf_get_next_key(maps->d_v4map.fd, &v4Key, &nextV4Key);
+ res = bpf_get_next_key(maps->d_v4map.getHandle(), &v4Key, &nextV4Key);
}
- res = bpf_get_next_key(maps->d_v6map.fd, &v6Key, &nextV6Key);
+ res = bpf_get_next_key(maps->d_v6map.getHandle(), &v6Key, &nextV6Key);
while (res == 0) {
- if (bpf_lookup_elem(maps->d_v6map.fd, &nextV6Key, &value) == 0) {
+ if (bpf_lookup_elem(maps->d_v6map.getHandle(), &nextV6Key, &value) == 0) {
memcpy(&v6Addr.sin6_addr.s6_addr, &nextV6Key, sizeof(nextV6Key));
result.push_back(make_pair(ComboAddress(&v6Addr), value));
}
- res = bpf_get_next_key(maps->d_v6map.fd, &nextV6Key, &nextV6Key);
+ res = bpf_get_next_key(maps->d_v6map.getHandle(), &nextV6Key, &nextV6Key);
}
return result;
}
auto maps = d_maps.lock();
result.reserve(maps->d_qNamesCount);
- int res = bpf_get_next_key(maps->d_qnamemap.fd, &key, &nextKey);
+ int res = bpf_get_next_key(maps->d_qnamemap.getHandle(), &key, &nextKey);
while (res == 0) {
- if (bpf_lookup_elem(maps->d_qnamemap.fd, &nextKey, &value) == 0) {
+ if (bpf_lookup_elem(maps->d_qnamemap.getHandle(), &nextKey, &value) == 0) {
nextKey.qname[sizeof(nextKey.qname) - 1 ] = '\0';
result.push_back(std::make_tuple(DNSName((const char*) nextKey.qname, sizeof(nextKey.qname), 0, false), value.qtype, value.counter));
}
- res = bpf_get_next_key(maps->d_qnamemap.fd, &nextKey, &nextKey);
+ res = bpf_get_next_key(maps->d_qnamemap.getHandle(), &nextKey, &nextKey);
}
return result;
}
uint32_t key = htonl(requestor.sin4.sin_addr.s_addr);
auto maps = d_maps.lock();
- int res = bpf_lookup_elem(maps->d_v4map.fd, &key, &counter);
+ int res = bpf_lookup_elem(maps->d_v4map.getHandle(), &key, &counter);
if (res == 0) {
return counter;
}
}
auto maps = d_maps.lock();
- int res = bpf_lookup_elem(maps->d_v6map.fd, &key, &counter);
+ int res = bpf_lookup_elem(maps->d_v6map.getHandle(), &key, &counter);
if (res == 0) {
return counter;
}
private:
#ifdef HAVE_EBPF
- struct FDWrapper
- {
- ~FDWrapper()
- {
- if (fd != -1) {
- close(fd);
- }
- }
- int fd{-1};
- };
struct Maps
{
FDWrapper d_v4map;
BPF_JMP_IMM(BPF_JNE,BPF_REG_1,ntohs(0x0800),109),
BPF_LD_ABS(BPF_W,-2097126),
BPF_STX_MEM(BPF_W,BPF_REG_10,BPF_REG_0,-256),
-BPF_LD_MAP_FD(BPF_REG_1,maps->d_v4map.fd),
+BPF_LD_MAP_FD(BPF_REG_1,maps->d_v4map.getHandle()),
BPF_MOV64_REG(BPF_REG_2,BPF_REG_10),
BPF_ALU64_IMM(BPF_ADD,BPF_REG_2,-256),
BPF_RAW_INSN(BPF_JMP|BPF_CALL,0,0,0,BPF_FUNC_map_lookup_elem),
BPF_STX_MEM(BPF_B,BPF_REG_10,BPF_REG_0,-242),
BPF_LD_ABS(BPF_B,-2097115),
BPF_STX_MEM(BPF_B,BPF_REG_10,BPF_REG_0,-241),
-BPF_LD_MAP_FD(BPF_REG_1,maps->d_v6map.fd),
+BPF_LD_MAP_FD(BPF_REG_1,maps->d_v6map.getHandle()),
BPF_MOV64_REG(BPF_REG_2,BPF_REG_10),
BPF_ALU64_IMM(BPF_ADD,BPF_REG_2,-256),
BPF_RAW_INSN(BPF_JMP|BPF_CALL,0,0,0,BPF_FUNC_map_lookup_elem),
BPF_JMP_IMM(BPF_JNE,BPF_REG_8,0,18),
BPF_LD_ABS(BPF_H,21),
BPF_MOV64_REG(BPF_REG_6,BPF_REG_0),
-BPF_LD_MAP_FD(BPF_REG_1,maps->d_qnamemap.fd),
+BPF_LD_MAP_FD(BPF_REG_1,maps->d_qnamemap.getHandle()),
BPF_MOV64_REG(BPF_REG_2,BPF_REG_10),
BPF_ALU64_IMM(BPF_ADD,BPF_REG_2,-256),
BPF_RAW_INSN(BPF_JMP|BPF_CALL,0,0,0,BPF_FUNC_map_lookup_elem),
BPF_STX_MEM(BPF_W,BPF_REG_6,BPF_REG_8,60),
BPF_ALU64_IMM(BPF_AND,BPF_REG_1,255),
BPF_STX_MEM(BPF_W,BPF_REG_6,BPF_REG_1,56),
-BPF_LD_MAP_FD(BPF_REG_2,maps->d_filtermap.fd),
+BPF_LD_MAP_FD(BPF_REG_2,maps->d_filtermap.getHandle()),
BPF_MOV64_REG(BPF_REG_1,BPF_REG_6),
BPF_MOV64_IMM(BPF_REG_3,0),
BPF_RAW_INSN(BPF_JMP|BPF_CALL,0,0,0,BPF_FUNC_tail_call),
BPF_ALU64_REG(BPF_ADD,BPF_REG_9,BPF_REG_7),
BPF_RAW_INSN(BPF_LD|BPF_IND|BPF_H,BPF_REG_0,BPF_REG_9,0,0),
BPF_MOV64_REG(BPF_REG_6,BPF_REG_0),
-BPF_LD_MAP_FD(BPF_REG_1,maps->d_qnamemap.fd),
+BPF_LD_MAP_FD(BPF_REG_1,maps->d_qnamemap.getHandle()),
BPF_MOV64_REG(BPF_REG_2,BPF_REG_10),
BPF_ALU64_IMM(BPF_ADD,BPF_REG_2,-256),
BPF_RAW_INSN(BPF_JMP|BPF_CALL,0,0,0,BPF_FUNC_map_lookup_elem),
return downstream;
}
-static void tcpClientThread(int pipefd, int crossProtocolPipeFD);
+static void tcpClientThread(int pipefd, int crossProtocolQueriesPipeFD, int crossProtocolResponsesListenPipeFD, int crossProtocolResponsesWritePipeFD);
TCPClientCollection::TCPClientCollection(size_t maxThreads): d_tcpclientthreads(maxThreads), d_maxthreads(maxThreads)
{
+ for (size_t idx = 0; idx < maxThreads; idx++) {
+ addTCPClientThread();
+ }
}
void TCPClientCollection::addTCPClientThread()
return;
}
- int crossProtocolFDs[2] = { -1, -1};
- if (!preparePipe(crossProtocolFDs, "cross-protocol")) {
+ int crossProtocolQueriesFDs[2] = { -1, -1};
+ if (!preparePipe(crossProtocolQueriesFDs, "cross-protocol queries")) {
+ return;
+ }
+
+ int crossProtocolResponsesFDs[2] = { -1, -1};
+ if (!preparePipe(crossProtocolResponsesFDs, "cross-protocol responses")) {
return;
}
vinfolog("Adding TCP Client thread");
{
- std::lock_guard<std::mutex> lock(d_mutex);
-
if (d_numthreads >= d_tcpclientthreads.size()) {
vinfolog("Adding a new TCP client thread would exceed the vector size (%d/%d), skipping. Consider increasing the maximum amount of TCP client threads with setMaxTCPClientThreads() in the configuration.", d_numthreads.load(), d_tcpclientthreads.size());
- close(crossProtocolFDs[0]);
- close(crossProtocolFDs[1]);
+ close(crossProtocolQueriesFDs[0]);
+ close(crossProtocolQueriesFDs[1]);
+ close(crossProtocolResponsesFDs[0]);
+ close(crossProtocolResponsesFDs[1]);
close(pipefds[0]);
close(pipefds[1]);
return;
/* from now on this side of the pipe will be managed by that object,
no need to worry about it */
- TCPWorkerThread worker(pipefds[1], crossProtocolFDs[1]);
+ TCPWorkerThread worker(pipefds[1], crossProtocolQueriesFDs[1], crossProtocolResponsesFDs[1]);
try {
- std::thread t1(tcpClientThread, pipefds[0], crossProtocolFDs[0]);
+ std::thread t1(tcpClientThread, pipefds[0], crossProtocolQueriesFDs[0], crossProtocolResponsesFDs[0], crossProtocolResponsesFDs[1]);
t1.detach();
}
catch (const std::runtime_error& e) {
/* the thread creation failed, don't leak */
errlog("Error creating a TCP thread: %s", e.what());
close(pipefds[0]);
+ close(crossProtocolQueriesFDs[0]);
+ close(crossProtocolResponsesFDs[0]);
return;
}
queueResponse(state, now, std::move(response));
}
+/* Carries a response coming from a cross-protocol (DoH) backend back to the
+   TCP worker thread that owns the incoming connection. An empty response
+   buffer signals an I/O error (see notifyIOError in the sender). Ownership of
+   this object is transferred through a pipe as a raw pointer, so the reader
+   is responsible for deleting it. */
+struct TCPCrossProtocolResponse
+{
+ TCPCrossProtocolResponse(TCPResponse&& response, std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now): d_response(std::move(response)), d_state(state), d_now(now)
+ {
+ }
+
+ TCPResponse d_response;
+ std::shared_ptr<IncomingTCPConnectionState> d_state;
+ struct timeval d_now;
+};
+
+/* TCPQuerySender implementation used when a query received over TCP has to be
+   forwarded to a backend over a different protocol (DoH). Responses cannot be
+   delivered directly because they arrive on a different thread, so they are
+   handed back to the owning TCP worker thread through a pipe
+   (d_responseDesc), as a heap-allocated TCPCrossProtocolResponse pointer. */
+class TCPCrossProtocolQuerySender : public TCPQuerySender
+{
+public:
+ TCPCrossProtocolQuerySender(std::shared_ptr<IncomingTCPConnectionState>& state, int responseDescriptor): d_state(state), d_responseDesc(responseDescriptor)
+ {
+ }
+
+ bool active() const override
+ {
+   return d_state->active();
+ }
+
+ const ClientState* getClientState() const override
+ {
+   return d_state->getClientState();
+ }
+
+ /* Passes the response to the TCP worker thread owning d_state by writing a
+    pointer to a heap-allocated TCPCrossProtocolResponse into the pipe.
+    Ownership is transferred on a successful write; on failure the object is
+    deleted here so it cannot leak. */
+ void handleResponse(const struct timeval& now, TCPResponse&& response) override
+ {
+   if (d_responseDesc == -1) {
+     throw std::runtime_error("Invalid pipe descriptor in TCP Cross Protocol Query Sender");
+   }
+
+   auto ptr = new TCPCrossProtocolResponse(std::move(response), d_state, now);
+   static_assert(sizeof(ptr) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail");
+   ssize_t sent = write(d_responseDesc, &ptr, sizeof(ptr));
+   if (sent != sizeof(ptr)) {
+     if (errno == EAGAIN || errno == EWOULDBLOCK) {
+       vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because the pipe is full");
+     }
+     else {
+       vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror());
+     }
+     delete ptr;
+   }
+ }
+
+ void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override
+ {
+   handleResponse(now, std::move(response));
+ }
+
+ /* An I/O error is signalled to the worker thread as a response with an
+    empty buffer (see handleCrossProtocolResponse). */
+ void notifyIOError(IDState&& query, const struct timeval& now) override
+ {
+   TCPResponse response(PacketBuffer(), std::move(query), nullptr);
+   handleResponse(now, std::move(response));
+ }
+
+private:
+ std::shared_ptr<IncomingTCPConnectionState> d_state;
+ int d_responseDesc{-1};
+};
+
class TCPCrossProtocolQuery : public CrossProtocolQuery
{
public:
- TCPCrossProtocolQuery(PacketBuffer&& buffer, IDState&& ids, std::shared_ptr<DownstreamState>& ds, std::shared_ptr<IncomingTCPConnectionState>& sender): d_sender(sender)
+ TCPCrossProtocolQuery(PacketBuffer&& buffer, IDState&& ids, std::shared_ptr<DownstreamState>& ds, std::shared_ptr<TCPCrossProtocolQuerySender>& sender): d_sender(sender)
{
query = InternalQuery(std::move(buffer), std::move(ids));
downstream = ds;
}
private:
- std::shared_ptr<IncomingTCPConnectionState> d_sender;
+ std::shared_ptr<TCPCrossProtocolQuerySender> d_sender;
};
static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
++state->d_currentQueriesCount;
if (ds->isDoH()) {
- std::shared_ptr<TCPQuerySender> incoming = state;
- auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(state->d_buffer), std::move(ids), ds, state);
+ auto incoming = std::make_shared<TCPCrossProtocolQuerySender>(state, state->d_threadData.crossProtocolResponsesPipe);
+ auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(state->d_buffer), std::move(ids), ds, incoming);
ds->passCrossProtocolQuery(std::move(cpq));
return;
prependSizeToTCPQuery(state->d_buffer, 0);
-#warning FIXME: handle DoH backends here
auto downstreamConnection = state->getDownstreamConnection(ds, dq.proxyProtocolValues, now);
bool proxyProtocolPayloadAdded = false;
}
}
-static void tcpClientThread(int pipefd, int crossProtocolPipeFD)
+/* Multiplexer callback for the cross-protocol responses pipe: reads a pointer
+   to a heap-allocated TCPCrossProtocolResponse written by a
+   TCPCrossProtocolQuerySender (possibly from another thread), takes ownership
+   of it, and dispatches the response to the owning incoming TCP connection
+   state. An empty buffer signals an I/O error; AXFR/IXFR responses go through
+   the XFR path; everything else is a regular response. */
+static void handleCrossProtocolResponse(int pipefd, FDMultiplexer::funcparam_t& param)
+{
+ TCPCrossProtocolResponse* tmp{nullptr};
+
+ ssize_t got = read(pipefd, &tmp, sizeof(tmp));
+ if (got == 0) {
+   throw std::runtime_error("EOF while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
+ }
+ else if (got == -1) {
+   if (errno == EAGAIN || errno == EINTR) {
+     return;
+   }
+   throw std::runtime_error("Error while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror());
+ }
+ else if (got != sizeof(tmp)) {
+   throw std::runtime_error("Partial read while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
+ }
+
+ /* take ownership: move the payload out, then free the heap object */
+ auto response = std::move(*tmp);
+ delete tmp;
+ tmp = nullptr;
+
+ if (response.d_response.d_buffer.empty()) {
+   response.d_state->notifyIOError(std::move(response.d_response.d_idstate), response.d_now);
+ }
+ else if (response.d_response.d_idstate.qtype == QType::AXFR || response.d_response.d_idstate.qtype == QType::IXFR) {
+   response.d_state->handleXFRResponse(response.d_now, std::move(response.d_response));
+ }
+ else {
+   /* fixed: regular (non-XFR) responses must be dispatched via handleResponse(),
+      not handleXFRResponse() — the previous else branch duplicated the XFR call,
+      making the AXFR/IXFR check above a no-op and routing every normal response
+      through the XFR handling path */
+   response.d_state->handleResponse(response.d_now, std::move(response.d_response));
+ }
+}
+
+static void tcpClientThread(int pipefd, int crossProtocolQueriesPipeFD, int crossProtocolResponsesListenPipeFD, int crossProtocolResponsesWritePipeFD)
{
/* we get launched with a pipe on which we receive file descriptors from clients that we own
from that point on */
setThreadName("dnsdist/tcpClie");
TCPClientThreadData data;
-
+ /* this is the writing end! */
+ data.crossProtocolResponsesPipe = crossProtocolResponsesWritePipeFD;
data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data);
- data.mplexer->addReadFD(crossProtocolPipeFD, handleCrossProtocolQuery, &data);
+ data.mplexer->addReadFD(crossProtocolQueriesPipeFD, handleCrossProtocolQuery, &data);
+ data.mplexer->addReadFD(crossProtocolResponsesListenPipeFD, handleCrossProtocolResponse, &data);
struct timeval now;
gettimeofday(&now, nullptr);
return true;
}
- const ClientState* getClientState() override
+ const ClientState* getClientState() const override
{
return &d_cs;
}
g_maxTCPClientThreads = 1;
}
+ /* we need to create the TCP worker threads before the
+ acceptor ones, otherwise we might crash when processing
+ the first TCP query */
g_tcpclientthreads = std::make_unique<TCPClientCollection>(*g_maxTCPClientThreads);
initDoHWorkers();
}
handleQueuedHealthChecks(*mplexer, true);
- /* we need to create the TCP worker threads before the
- acceptor ones, otherwise we might crash when processing
- the first TCP query */
- while (!g_tcpclientthreads->hasReachedMaxThreads()) {
- g_tcpclientthreads->addTCPClientThread();
- }
-
for(auto& cs : g_frontends) {
if (cs->dohFrontend != nullptr) {
#ifdef HAVE_DNS_OVER_HTTPS
return true;
}
- const ClientState* getClientState() override
+ const ClientState* getClientState() const override
{
return nullptr;
}
else if (ds->isDoH()) {
InternalQuery query(std::move(packet), IDState());
auto sender = std::shared_ptr<TCPQuerySender>(new HealthCheckQuerySender(data));
- if (!sendH2Query(ds, mplexer, sender, std::move(query))) {
+ if (!sendH2Query(ds, mplexer, sender, std::move(query), true)) {
updateHealthCheckResult(data->d_ds, data->d_initial, false);
}
}
#endif /* HAVE_NGHTTP2 */
}
-bool sendH2Query(const std::shared_ptr<DownstreamState>& ds, std::unique_ptr<FDMultiplexer>& mplexer, std::shared_ptr<TCPQuerySender>& sender, InternalQuery&& query)
+bool sendH2Query(const std::shared_ptr<DownstreamState>& ds, std::unique_ptr<FDMultiplexer>& mplexer, std::shared_ptr<TCPQuerySender>& sender, InternalQuery&& query, bool healthCheck)
{
#ifdef HAVE_NGHTTP2
struct timeval now;
gettimeofday(&now, nullptr);
auto newConnection = std::make_shared<DoHConnectionToBackend>(ds, mplexer, now);
- newConnection->setHealthCheck(true);
+ newConnection->setHealthCheck(healthCheck);
newConnection->queueQuery(sender, std::move(query));
return true;
#else /* HAVE_NGHTTP2 */
/* opens a new HTTP/2 connection to the supplied backend (attached to the supplied multiplexer), sends the query,
waits for the response to come back or an error to occur then notifies the sender, closing the connection. */
-bool sendH2Query(const std::shared_ptr<DownstreamState>& ds, std::unique_ptr<FDMultiplexer>& mplexer, std::shared_ptr<TCPQuerySender>& sender, InternalQuery&& query);
+bool sendH2Query(const std::shared_ptr<DownstreamState>& ds, std::unique_ptr<FDMultiplexer>& mplexer, std::shared_ptr<TCPQuerySender>& sender, InternalQuery&& query, bool healthCheck);
LocalHolders holders;
LocalStateHolder<vector<DNSDistResponseRuleAction> > localRespRuleActions;
std::unique_ptr<FDMultiplexer> mplexer{nullptr};
+ int crossProtocolResponsesPipe{-1};
};
class IncomingTCPConnectionState : public TCPQuerySender, public std::enable_shared_from_this<IncomingTCPConnectionState>
return d_ioState != nullptr;
}
- const ClientState* getClientState() override
+ const ClientState* getClientState() const override
{
return d_ci.cs;
}
}
virtual bool active() const = 0;
- virtual const ClientState* getClientState() = 0;
+ virtual const ClientState* getClientState() const = 0;
virtual void handleResponse(const struct timeval& now, TCPResponse&& response) = 0;
virtual void handleXFRResponse(const struct timeval& now, TCPResponse&& response) = 0;
virtual void notifyIOError(IDState&& query, const struct timeval& now) = 0;
uint64_t pos = d_pos++;
++d_queued;
- return d_tcpclientthreads.at(pos % d_numthreads).d_newConnectionPipe;
+ return d_tcpclientthreads.at(pos % d_numthreads).d_newConnectionPipe.getHandle();
}
bool passConnectionToThread(std::unique_ptr<ConnectionInfo>&& conn)
}
uint64_t pos = d_pos++;
- auto pipe = d_tcpclientthreads.at(pos % d_numthreads).d_newConnectionPipe;
+ auto pipe = d_tcpclientthreads.at(pos % d_numthreads).d_newConnectionPipe.getHandle();
auto tmp = conn.release();
if (write(pipe, &tmp, sizeof(tmp)) != sizeof(tmp)) {
}
uint64_t pos = d_pos++;
- auto pipe = d_tcpclientthreads.at(pos % d_numthreads).d_crossProtocolQueryPipe;
+ auto pipe = d_tcpclientthreads.at(pos % d_numthreads).d_crossProtocolQueriesPipe.getHandle();
auto tmp = cpq.release();
if (write(pipe, &tmp, sizeof(tmp)) != sizeof(tmp)) {
--d_queued;
}
+private:
void addTCPClientThread();
-private:
struct TCPWorkerThread
{
TCPWorkerThread()
{
}
- TCPWorkerThread(int newConnPipe, int crossProtocolPipe) :
- d_newConnectionPipe(newConnPipe), d_crossProtocolQueryPipe(crossProtocolPipe)
- {
- }
-
- TCPWorkerThread(TCPWorkerThread&& rhs) :
- d_newConnectionPipe(rhs.d_newConnectionPipe), d_crossProtocolQueryPipe(rhs.d_crossProtocolQueryPipe)
+ TCPWorkerThread(int newConnPipe, int crossProtocolQueriesPipe, int crossProtocolResponsesPipe) :
+ d_newConnectionPipe(newConnPipe), d_crossProtocolQueriesPipe(crossProtocolQueriesPipe), d_crossProtocolResponsesPipe(crossProtocolResponsesPipe)
{
- rhs.d_newConnectionPipe = -1;
- rhs.d_crossProtocolQueryPipe = -1;
- }
-
- TCPWorkerThread& operator=(TCPWorkerThread&& rhs)
- {
- if (d_newConnectionPipe != -1) {
- close(d_newConnectionPipe);
- }
- if (d_crossProtocolQueryPipe != -1) {
- close(d_crossProtocolQueryPipe);
- }
-
- d_newConnectionPipe = rhs.d_newConnectionPipe;
- d_crossProtocolQueryPipe = rhs.d_crossProtocolQueryPipe;
- rhs.d_newConnectionPipe = -1;
- rhs.d_crossProtocolQueryPipe = -1;
-
- return *this;
}
+ TCPWorkerThread(TCPWorkerThread&& rhs) = default;
+ TCPWorkerThread& operator=(TCPWorkerThread&& rhs) = default;
TCPWorkerThread(const TCPWorkerThread& rhs) = delete;
TCPWorkerThread& operator=(const TCPWorkerThread&) = delete;
- ~TCPWorkerThread()
- {
- if (d_newConnectionPipe != -1) {
- close(d_newConnectionPipe);
- }
- if (d_crossProtocolQueryPipe != -1) {
- close(d_crossProtocolQueryPipe);
- }
- }
-
- int d_newConnectionPipe{-1};
- int d_crossProtocolQueryPipe{-1};
+ FDWrapper d_newConnectionPipe;
+ FDWrapper d_crossProtocolQueriesPipe;
+ FDWrapper d_crossProtocolResponsesPipe;
};
- std::mutex d_mutex;
std::vector<TCPWorkerThread> d_tcpclientthreads;
stat_t d_numthreads{0};
stat_t d_pos{0};
return true;
}
- const ClientState* getClientState() override
+ const ClientState* getClientState() const override
{
if (!du || !du->dsc || !du->dsc->cs) {
throw std::runtime_error("No query associated to this DoHTCPCrossQuerySender");
// Used in NID and L64 records
struct NodeOrLocatorID { uint8_t content[8]; };
+
+/* RAII wrapper around a raw file descriptor: owns the descriptor and closes
+   it exactly once, on destruction or when overwritten by move-assignment.
+   Move-only (copies are deleted) so ownership is never duplicated; -1 means
+   "no descriptor owned". */
+struct FDWrapper
+{
+ FDWrapper()
+ {
+ }
+
+ FDWrapper(int desc): d_fd(desc)
+ {
+ }
+
+ ~FDWrapper()
+ {
+   if (d_fd != -1) {
+     close(d_fd);
+     d_fd = -1;
+   }
+ }
+
+ FDWrapper(const FDWrapper&) = delete;
+ FDWrapper& operator=(const FDWrapper&) = delete;
+
+ /* noexcept so containers (std::vector growth) can move instead of copy */
+ FDWrapper(FDWrapper&& rhs) noexcept : d_fd(rhs.d_fd)
+ {
+   rhs.d_fd = -1;
+ }
+
+ FDWrapper& operator=(FDWrapper&& rhs) noexcept
+ {
+   /* fixed: the previous guard was 'if (d_fd)', which called close(-1) when
+      no descriptor was owned and would have leaked a descriptor equal to 0 */
+   if (d_fd != -1) {
+     close(d_fd);
+   }
+   d_fd = rhs.d_fd;
+   rhs.d_fd = -1;
+   return *this;
+ }
+
+ int getHandle() const
+ {
+   return d_fd;
+ }
+
+private:
+ int d_fd{-1};
+};
Socket& operator=(Socket&& rhs)
{
+ if (d_socket != -1) {
+ close(d_socket);
+ }
d_socket = rhs.d_socket;
rhs.d_socket = -1;
d_buffer = std::move(rhs.d_buffer);
(conn, _) = sock.accept()
except ssl.SSLError:
continue
+ except ConnectionResetError:
+ continue
+
conn.settimeout(5.0)
data = conn.recv(2)
if not data:
continue
except ConnectionResetError:
continue
+
conn.settimeout(5.0)
h2conn = h2.connection.H2Connection(config=config)
h2conn.initiate_connection()