#include "dnsdist-downstream-connection.hh"
#include "dolog.hh"
+#include "channel.hh"
#include "iputils.hh"
#include "libssl.hh"
#include "noinitvector.hh"
class DoHClientThreadData
{
public:
- DoHClientThreadData() :
- mplexer(std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent()))
+ DoHClientThreadData(pdns::channel::Receiver<CrossProtocolQuery>&& receiver) :
+ mplexer(std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent())),
+ d_receiver(std::move(receiver))
{
}
std::unique_ptr<FDMultiplexer> mplexer{nullptr};
+ pdns::channel::Receiver<CrossProtocolQuery> d_receiver;
};
void DoHConnectionToBackend::handleReadableIOCallback(int fd, FDMultiplexer::funcparam_t& param)
static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& param)
{
auto threadData = boost::any_cast<DoHClientThreadData*>(param);
- CrossProtocolQuery* tmp{nullptr};
- ssize_t got = read(pipefd, &tmp, sizeof(tmp));
- if (got == 0) {
- throw std::runtime_error("EOF while reading from the DoH cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
- }
- else if (got == -1) {
- if (errno == EAGAIN || errno == EINTR) {
+ std::unique_ptr<CrossProtocolQuery> cpq{nullptr};
+ try {
+ auto tmp = threadData->d_receiver.receive();
+ if (!tmp) {
return;
}
- throw std::runtime_error("Error while reading from the DoH cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror());
+ cpq = std::move(*tmp);
}
- else if (got != sizeof(tmp)) {
- throw std::runtime_error("Partial read while reading from the DoH cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
+ catch (const std::exception& e) {
+ throw std::runtime_error("Error while reading from the DoH cross-protocol channel:" + std::string(e.what()));
}
- try {
- struct timeval now;
- gettimeofday(&now, nullptr);
+ struct timeval now;
+ gettimeofday(&now, nullptr);
- std::shared_ptr<TCPQuerySender> tqs = tmp->getTCPQuerySender();
- auto query = std::move(tmp->query);
- auto downstreamServer = std::move(tmp->downstream);
- delete tmp;
- tmp = nullptr;
+ std::shared_ptr<TCPQuerySender> tqs = cpq->getTCPQuerySender();
+ auto query = std::move(cpq->query);
+ auto downstreamServer = std::move(cpq->downstream);
+ cpq.reset();
- try {
- auto downstream = t_downstreamDoHConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::move(query.d_proxyProtocolPayload));
- downstream->queueQuery(tqs, std::move(query));
- }
- catch (...) {
- tqs->notifyIOError(std::move(query.d_idstate), now);
- }
+ try {
+ auto downstream = t_downstreamDoHConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::move(query.d_proxyProtocolPayload));
+ downstream->queueQuery(tqs, std::move(query));
}
catch (...) {
- delete tmp;
- tmp = nullptr;
+ tqs->notifyIOError(std::move(query.d_idstate), now);
}
}
-static void dohClientThread(int crossProtocolPipeFD)
+static void dohClientThread(pdns::channel::Receiver<CrossProtocolQuery>&& receiver)
{
setThreadName("dnsdist/dohClie");
try {
- DoHClientThreadData data;
- data.mplexer->addReadFD(crossProtocolPipeFD, handleCrossProtocolQuery, &data);
+ DoHClientThreadData data(std::move(receiver));
+ data.mplexer->addReadFD(data.d_receiver.getDescriptor(), handleCrossProtocolQuery, &data);
struct timeval now;
gettimeofday(&now, nullptr);
{
}
- DoHWorkerThread(int crossProtocolPipe) :
- d_crossProtocolQueryPipe(crossProtocolPipe)
+ DoHWorkerThread(pdns::channel::Sender<CrossProtocolQuery>&& sender) :
+ d_sender(std::move(sender))
{
}
DoHWorkerThread(DoHWorkerThread&& rhs) :
- d_crossProtocolQueryPipe(rhs.d_crossProtocolQueryPipe)
+ d_sender(std::move(rhs.d_sender))
{
- rhs.d_crossProtocolQueryPipe = -1;
}
DoHWorkerThread& operator=(DoHWorkerThread&& rhs)
{
- if (d_crossProtocolQueryPipe != -1) {
- close(d_crossProtocolQueryPipe);
- }
-
- d_crossProtocolQueryPipe = rhs.d_crossProtocolQueryPipe;
- rhs.d_crossProtocolQueryPipe = -1;
-
+ d_sender = std::move(rhs.d_sender);
return *this;
}
DoHWorkerThread(const DoHWorkerThread& rhs) = delete;
DoHWorkerThread& operator=(const DoHWorkerThread&) = delete;
- ~DoHWorkerThread()
- {
- if (d_crossProtocolQueryPipe != -1) {
- close(d_crossProtocolQueryPipe);
- }
- }
-
- int d_crossProtocolQueryPipe{-1};
+ pdns::channel::Sender<CrossProtocolQuery> d_sender;
};
DoHClientCollection::DoHClientCollection(size_t numberOfThreads) :
}
uint64_t pos = d_pos++;
- auto pipe = d_clientThreads.at(pos % d_numberOfThreads).d_crossProtocolQueryPipe;
- auto tmp = cpq.release();
-
- if (write(pipe, &tmp, sizeof(tmp)) != sizeof(tmp)) {
- delete tmp;
+ if (!d_clientThreads.at(pos % d_numberOfThreads).d_sender.send(std::move(cpq))) {
++g_stats.outgoingDoHQueryPipeFull;
- tmp = nullptr;
return false;
}
void DoHClientCollection::addThread()
{
#ifdef HAVE_NGHTTP2
- auto preparePipe = [](int fds[2], const std::string& type) -> bool {
- if (pipe(fds) < 0) {
- errlog("Error creating the DoH thread %s pipe: %s", type, stringerror());
- return false;
- }
-
- if (!setNonBlocking(fds[0])) {
- int err = errno;
- close(fds[0]);
- close(fds[1]);
- errlog("Error setting the DoH thread %s pipe non-blocking: %s", type, stringerror(err));
- return false;
- }
-
- if (!setNonBlocking(fds[1])) {
- int err = errno;
- close(fds[0]);
- close(fds[1]);
- errlog("Error setting the DoH thread %s pipe non-blocking: %s", type, stringerror(err));
- return false;
- }
-
- if (g_tcpInternalPipeBufferSize > 0 && getPipeBufferSize(fds[0]) < g_tcpInternalPipeBufferSize) {
- setPipeBufferSize(fds[0], g_tcpInternalPipeBufferSize);
- }
-
- return true;
- };
-
- int crossProtocolFDs[2] = {-1, -1};
- if (!preparePipe(crossProtocolFDs, "cross-protocol")) {
- return;
- }
-
- vinfolog("Adding DoH Client thread");
+ try {
+ auto [sender, receiver] = pdns::channel::createObjectQueue<CrossProtocolQuery>(true, g_tcpInternalPipeBufferSize);
- {
+ vinfolog("Adding DoH Client thread");
std::lock_guard<std::mutex> lock(d_mutex);
if (d_numberOfThreads >= d_clientThreads.size()) {
vinfolog("Adding a new DoH client thread would exceed the vector size (%d/%d), skipping. Consider increasing the maximum amount of DoH client threads with setMaxDoHClientThreads() in the configuration.", d_numberOfThreads, d_clientThreads.size());
- close(crossProtocolFDs[0]);
- close(crossProtocolFDs[1]);
return;
}
- /* from now on this side of the pipe will be managed by that object,
- no need to worry about it */
- DoHWorkerThread worker(crossProtocolFDs[1]);
+ DoHWorkerThread worker(std::move(sender));
try {
- std::thread t1(dohClientThread, crossProtocolFDs[0]);
+ std::thread t1(dohClientThread, std::move(receiver));
t1.detach();
}
catch (const std::runtime_error& e) {
- /* the thread creation failed, don't leak */
+ /* the thread creation failed */
errlog("Error creating a DoH thread: %s", e.what());
- close(crossProtocolFDs[0]);
return;
}
d_clientThreads.at(d_numberOfThreads) = std::move(worker);
++d_numberOfThreads;
}
+ catch (const std::exception& e) {
+ errlog("Error creating the DoH channel: %s", e.what());
+ return;
+ }
#else /* HAVE_NGHTTP2 */
throw std::runtime_error("DoHClientCollection::addThread() called but nghttp2 support is not available");
#endif /* HAVE_NGHTTP2 */