dns_random.hh \
dnscrypt.cc dnscrypt.hh \
dnsdist-actions-factory-generated.hh \
+ dnsdist-actions-factory.cc dnsdist-actions-factory.hh \
dnsdist-actions.cc dnsdist-actions.hh \
dnsdist-async.cc dnsdist-async.hh \
dnsdist-backend.cc dnsdist-backend.hh \
if HAVE_RE2
dnsdist_LDADD += $(RE2_LIBS)
+testrunner_LDADD += $(RE2_LIBS)
endif
if HAVE_LIBSSL
--- /dev/null
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#include <unordered_map>
+
+#include "dnsdist-actions-factory.hh"
+
+#include "config.h"
+#include "dnsdist.hh"
+#include "dnsdist-async.hh"
+#include "dnsdist-dnsparser.hh"
+#include "dnsdist-ecs.hh"
+#include "dnsdist-edns.hh"
+#include "dnsdist-lua.hh"
+#include "dnsdist-lua-ffi.hh"
+#include "dnsdist-mac-address.hh"
+#include "dnsdist-protobuf.hh"
+#include "dnsdist-proxy-protocol.hh"
+#include "dnsdist-kvs.hh"
+#include "dnsdist-rule-chains.hh"
+#include "dnsdist-self-answers.hh"
+#include "dnsdist-snmp.hh"
+
+#include "dnstap.hh"
+#include "dnswriter.hh"
+#include "ednsoptions.hh"
+#include "fstrm_logger.hh"
+#include "ipcipher.hh"
+#include "remote_logger.hh"
+#include "svc-records.hh"
+#include "threadname.hh"
+
+namespace dnsdist::actions
+{
+class DropAction : public DNSAction
+{
+public:
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ return Action::Drop;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "drop";
+ }
+};
+
+class AllowAction : public DNSAction
+{
+public:
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ return Action::Allow;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "allow";
+ }
+};
+
+class NoneAction : public DNSAction
+{
+public:
+ // this action does not stop the processing
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ return Action::None;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "no op";
+ }
+};
+
+class QPSAction : public DNSAction
+{
+public:
+ QPSAction(int limit) :
+ d_qps(QPSLimiter(limit, limit))
+ {
+ }
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ if (d_qps.lock()->check()) {
+ return Action::None;
+ }
+ return Action::Drop;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "qps limit to " + std::to_string(d_qps.lock()->getRate());
+ }
+
+private:
+ mutable LockGuarded<QPSLimiter> d_qps;
+};
+
+class DelayAction : public DNSAction
+{
+public:
+ DelayAction(int msec) :
+ d_msec(msec)
+ {
+ }
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ *ruleresult = std::to_string(d_msec);
+ return Action::Delay;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "delay by " + std::to_string(d_msec) + " ms";
+ }
+
+private:
+ int d_msec;
+};
+
+class TeeAction : public DNSAction
+{
+public:
+ // this action does not stop the processing
+ TeeAction(const ComboAddress& rca, const std::optional<ComboAddress>& lca, bool addECS = false, bool addProxyProtocol = false);
+ TeeAction(TeeAction& other) = delete;
+ TeeAction(TeeAction&& other) = delete;
+ TeeAction& operator=(TeeAction& other) = delete;
+ TeeAction& operator=(TeeAction&& other) = delete;
+ ~TeeAction() override;
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override;
+ [[nodiscard]] std::string toString() const override;
+ std::map<std::string, double> getStats() const override;
+
+private:
+ void worker();
+
+ ComboAddress d_remote;
+ std::thread d_worker;
+ Socket d_socket;
+ mutable std::atomic<unsigned long> d_senderrors{0};
+ unsigned long d_recverrors{0};
+ mutable std::atomic<unsigned long> d_queries{0};
+ stat_t d_responses{0};
+ stat_t d_nxdomains{0};
+ stat_t d_servfails{0};
+ stat_t d_refuseds{0};
+ stat_t d_formerrs{0};
+ stat_t d_notimps{0};
+ stat_t d_noerrors{0};
+ mutable stat_t d_tcpdrops{0};
+ stat_t d_otherrcode{0};
+ std::atomic<bool> d_pleaseQuit{false};
+ bool d_addECS{false};
+ bool d_addProxyProtocol{false};
+};
+
+TeeAction::TeeAction(const ComboAddress& rca, const std::optional<ComboAddress>& lca, bool addECS, bool addProxyProtocol) :
+ d_remote(rca), d_socket(d_remote.sin4.sin_family, SOCK_DGRAM, 0), d_addECS(addECS), d_addProxyProtocol(addProxyProtocol)
+{
+ if (lca) {
+ d_socket.bind(*lca, false);
+ }
+ d_socket.connect(d_remote);
+ d_socket.setNonBlocking();
+ d_worker = std::thread([this]() {
+ worker();
+ });
+}
+
+TeeAction::~TeeAction()
+{
+ d_pleaseQuit = true;
+ close(d_socket.releaseHandle());
+ d_worker.join();
+}
+
+DNSAction::Action TeeAction::operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const
+{
+ if (dnsquestion->overTCP()) {
+ d_tcpdrops++;
+ return DNSAction::Action::None;
+ }
+
+ d_queries++;
+
+ PacketBuffer query;
+ if (d_addECS) {
+ query = dnsquestion->getData();
+ bool ednsAdded = false;
+ bool ecsAdded = false;
+
+ std::string newECSOption;
+ generateECSOption(dnsquestion->ecs ? dnsquestion->ecs->getNetwork() : dnsquestion->ids.origRemote, newECSOption, dnsquestion->ecs ? dnsquestion->ecs->getBits() : dnsquestion->ecsPrefixLength);
+
+ if (!handleEDNSClientSubnet(query, dnsquestion->getMaximumSize(), dnsquestion->ids.qname.wirelength(), ednsAdded, ecsAdded, dnsquestion->ecsOverride, newECSOption)) {
+ return DNSAction::Action::None;
+ }
+ }
+
+ if (d_addProxyProtocol) {
+ auto proxyPayload = getProxyProtocolPayload(*dnsquestion);
+ if (query.empty()) {
+ query = dnsquestion->getData();
+ }
+ if (!addProxyProtocol(query, proxyPayload)) {
+ return DNSAction::Action::None;
+ }
+ }
+
+ {
+ const PacketBuffer& payload = query.empty() ? dnsquestion->getData() : query;
+ auto res = send(d_socket.getHandle(), payload.data(), payload.size(), 0);
+
+ if (res <= 0) {
+ d_senderrors++;
+ }
+ }
+
+ return DNSAction::Action::None;
+}
+
+std::string TeeAction::toString() const
+{
+ return "tee to " + d_remote.toStringWithPort();
+}
+
+std::map<std::string, double> TeeAction::getStats() const
+{
+ return {{"queries", d_queries},
+ {"responses", d_responses},
+ {"recv-errors", d_recverrors},
+ {"send-errors", d_senderrors},
+ {"noerrors", d_noerrors},
+ {"nxdomains", d_nxdomains},
+ {"refuseds", d_refuseds},
+ {"servfails", d_servfails},
+ {"other-rcode", d_otherrcode},
+ {"tcp-drops", d_tcpdrops}};
+}
+
+void TeeAction::worker()
+{
+ setThreadName("dnsdist/TeeWork");
+ std::array<char, dnsdist::configuration::s_udpIncomingBufferSize> packet{};
+ ssize_t res = 0;
+ const dnsheader_aligned dnsheader(packet.data());
+ for (;;) {
+ res = waitForData(d_socket.getHandle(), 0, 250000);
+ if (d_pleaseQuit) {
+ break;
+ }
+
+ if (res < 0) {
+ usleep(250000);
+ continue;
+ }
+ if (res == 0) {
+ continue;
+ }
+ res = recv(d_socket.getHandle(), packet.data(), packet.size(), 0);
+ if (static_cast<size_t>(res) <= sizeof(struct dnsheader)) {
+ d_recverrors++;
+ }
+ else {
+ d_responses++;
+ }
+
+ // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well
+ if (dnsheader->rcode == RCode::NoError) {
+ d_noerrors++;
+ }
+ // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well
+ else if (dnsheader->rcode == RCode::ServFail) {
+ d_servfails++;
+ }
+ // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well
+ else if (dnsheader->rcode == RCode::NXDomain) {
+ d_nxdomains++;
+ }
+ // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well
+ else if (dnsheader->rcode == RCode::Refused) {
+ d_refuseds++;
+ }
+ // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well
+ else if (dnsheader->rcode == RCode::FormErr) {
+ d_formerrs++;
+ }
+ // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well
+ else if (dnsheader->rcode == RCode::NotImp) {
+ d_notimps++;
+ }
+ }
+}
+
+class PoolAction : public DNSAction
+{
+public:
+ PoolAction(std::string pool, bool stopProcessing) :
+ d_pool(std::move(pool)), d_stopProcessing(stopProcessing) {}
+
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ if (d_stopProcessing) {
+ /* we need to do it that way to keep compatiblity with custom Lua actions returning DNSAction.Pool, 'poolname' */
+ *ruleresult = d_pool;
+ return Action::Pool;
+ }
+ dnsquestion->ids.poolName = d_pool;
+ return Action::None;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return "to pool " + d_pool;
+ }
+
+private:
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
+ const std::string d_pool;
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
+ const bool d_stopProcessing;
+};
+
+class QPSPoolAction : public DNSAction
+{
+public:
+ QPSPoolAction(unsigned int limit, std::string pool, bool stopProcessing) :
+ d_qps(QPSLimiter(limit, limit)), d_pool(std::move(pool)), d_stopProcessing(stopProcessing) {}
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ if (d_qps.lock()->check()) {
+ if (d_stopProcessing) {
+ /* we need to do it that way to keep compatiblity with custom Lua actions returning DNSAction.Pool, 'poolname' */
+ *ruleresult = d_pool;
+ return Action::Pool;
+ }
+ dnsquestion->ids.poolName = d_pool;
+ }
+ return Action::None;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "max " + std::to_string(d_qps.lock()->getRate()) + " to pool " + d_pool;
+ }
+
+private:
+ mutable LockGuarded<QPSLimiter> d_qps;
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
+ const std::string d_pool;
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
+ const bool d_stopProcessing;
+};
+
+class RCodeAction : public DNSAction
+{
+public:
+ RCodeAction(uint8_t rcode, const dnsdist::ResponseConfig& responseConfig) :
+ d_responseConfig(responseConfig), d_rcode(rcode) {}
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) {
+ header.rcode = d_rcode;
+ header.qr = true; // for good measure
+ setResponseHeadersFromConfig(header, d_responseConfig);
+ return true;
+ });
+ return Action::HeaderModify;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "set rcode " + std::to_string(d_rcode);
+ }
+
+private:
+ dnsdist::ResponseConfig d_responseConfig;
+ uint8_t d_rcode;
+};
+
+class ERCodeAction : public DNSAction
+{
+public:
+ ERCodeAction(uint8_t rcode, dnsdist::ResponseConfig responseConfig) :
+ d_responseConfig(std::move(responseConfig)), d_rcode(rcode)
+ {
+ }
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) {
+ header.rcode = (d_rcode & 0xF);
+ header.qr = true; // for good measure
+ setResponseHeadersFromConfig(header, d_responseConfig);
+ return true;
+ });
+ dnsquestion->ednsRCode = ((d_rcode & 0xFFF0) >> 4);
+ return Action::HeaderModify;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "set ercode " + ERCode::to_s(d_rcode);
+ }
+
+private:
+ dnsdist::ResponseConfig d_responseConfig;
+ uint8_t d_rcode;
+};
+
+class SpoofSVCAction : public DNSAction
+{
+public:
+ SpoofSVCAction(const std::vector<SVCRecordParameters>& parameters, const dnsdist::ResponseConfig& responseConfig) :
+ d_responseConfig(responseConfig)
+ {
+ d_payloads.reserve(parameters.size());
+
+ for (const auto& param : parameters) {
+ std::vector<uint8_t> payload;
+ if (!generateSVCPayload(payload, param)) {
+ throw std::runtime_error("Unable to generate a valid SVC record from the supplied parameters");
+ }
+
+ d_payloads.push_back(std::move(payload));
+
+ for (const auto& hint : param.ipv4hints) {
+ d_additionals4.insert({param.target, ComboAddress(hint)});
+ }
+
+ for (const auto& hint : param.ipv6hints) {
+ d_additionals6.insert({param.target, ComboAddress(hint)});
+ }
+ }
+ }
+
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ if (!dnsdist::svc::generateSVCResponse(*dnsquestion, d_payloads, d_additionals4, d_additionals6, d_responseConfig)) {
+ return Action::None;
+ }
+
+ return Action::HeaderModify;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return "spoof SVC record ";
+ }
+
+private:
+ dnsdist::ResponseConfig d_responseConfig;
+ std::vector<std::vector<uint8_t>> d_payloads{};
+ std::set<std::pair<DNSName, ComboAddress>> d_additionals4{};
+ std::set<std::pair<DNSName, ComboAddress>> d_additionals6{};
+};
+
+class TCAction : public DNSAction
+{
+public:
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ return Action::Truncate;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "tc=1 answer";
+ }
+};
+
+class TCResponseAction : public DNSResponseAction
+{
+public:
+ DNSResponseAction::Action operator()(DNSResponse* dnsResponse, std::string* ruleresult) const override
+ {
+ return Action::Truncate;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "tc=1 answer";
+ }
+};
+
+class LuaAction : public DNSAction
+{
+public:
+ LuaAction(LuaActionFunction func) :
+ d_func(std::move(func))
+ {}
+
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ try {
+ DNSAction::Action result{};
+ {
+ auto lock = g_lua.lock();
+ auto ret = d_func(dnsquestion);
+ if (ruleresult != nullptr) {
+ if (boost::optional<std::string> rule = std::get<1>(ret)) {
+ *ruleresult = *rule;
+ }
+ else {
+ // default to empty string
+ ruleresult->clear();
+ }
+ }
+ result = static_cast<Action>(std::get<0>(ret));
+ }
+ dnsdist::handleQueuedAsynchronousEvents();
+ return result;
+ }
+ catch (const std::exception& e) {
+ warnlog("LuaAction failed inside Lua, returning ServFail: %s", e.what());
+ }
+ catch (...) {
+ warnlog("LuaAction failed inside Lua, returning ServFail: [unknown exception]");
+ }
+ return DNSAction::Action::ServFail;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return "Lua script";
+ }
+
+private:
+ LuaActionFunction d_func;
+};
+
+class LuaResponseAction : public DNSResponseAction
+{
+public:
+ LuaResponseAction(LuaResponseActionFunction func) :
+ d_func(std::move(func))
+ {}
+ DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
+ {
+ try {
+ DNSResponseAction::Action result{};
+ {
+ auto lock = g_lua.lock();
+ auto ret = d_func(response);
+ if (ruleresult != nullptr) {
+ if (boost::optional<std::string> rule = std::get<1>(ret)) {
+ *ruleresult = *rule;
+ }
+ else {
+ // default to empty string
+ ruleresult->clear();
+ }
+ }
+ result = static_cast<Action>(std::get<0>(ret));
+ }
+ dnsdist::handleQueuedAsynchronousEvents();
+ return result;
+ }
+ catch (const std::exception& e) {
+ warnlog("LuaResponseAction failed inside Lua, returning ServFail: %s", e.what());
+ }
+ catch (...) {
+ warnlog("LuaResponseAction failed inside Lua, returning ServFail: [unknown exception]");
+ }
+ return DNSResponseAction::Action::ServFail;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return "Lua response script";
+ }
+
+private:
+ LuaResponseActionFunction d_func;
+};
+
+class LuaFFIAction : public DNSAction
+{
+public:
+ LuaFFIAction(LuaActionFFIFunction func) :
+ d_func(std::move(func))
+ {
+ }
+
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ dnsdist_ffi_dnsquestion_t dqffi(dnsquestion);
+ try {
+ DNSAction::Action result{};
+ {
+ auto lock = g_lua.lock();
+ auto ret = d_func(&dqffi);
+ if (ruleresult != nullptr) {
+ if (dqffi.result) {
+ *ruleresult = *dqffi.result;
+ }
+ else {
+ // default to empty string
+ ruleresult->clear();
+ }
+ }
+ result = static_cast<DNSAction::Action>(ret);
+ }
+ dnsdist::handleQueuedAsynchronousEvents();
+ return result;
+ }
+ catch (const std::exception& e) {
+ warnlog("LuaFFIAction failed inside Lua, returning ServFail: %s", e.what());
+ }
+ catch (...) {
+ warnlog("LuaFFIAction failed inside Lua, returning ServFail: [unknown exception]");
+ }
+ return DNSAction::Action::ServFail;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return "Lua FFI script";
+ }
+
+private:
+ LuaActionFFIFunction d_func;
+};
+
+class LuaFFIPerThreadAction : public DNSAction
+{
+public:
+ LuaFFIPerThreadAction(std::string code) :
+ d_functionCode(std::move(code)), d_functionID(s_functionsCounter++)
+ {
+ }
+
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ try {
+ auto& state = t_perThreadStates[d_functionID];
+ if (!state.d_initialized) {
+ setupLuaFFIPerThreadContext(state.d_luaContext);
+ /* mark the state as initialized first so if there is a syntax error
+ we only try to execute the code once */
+ state.d_initialized = true;
+ state.d_func = state.d_luaContext.executeCode<LuaActionFFIFunction>(d_functionCode);
+ }
+
+ if (!state.d_func) {
+ /* the function was not properly initialized */
+ return DNSAction::Action::None;
+ }
+
+ dnsdist_ffi_dnsquestion_t dqffi(dnsquestion);
+ auto ret = state.d_func(&dqffi);
+ if (ruleresult != nullptr) {
+ if (dqffi.result) {
+ *ruleresult = *dqffi.result;
+ }
+ else {
+ // default to empty string
+ ruleresult->clear();
+ }
+ }
+ dnsdist::handleQueuedAsynchronousEvents();
+ return static_cast<DNSAction::Action>(ret);
+ }
+ catch (const std::exception& e) {
+ warnlog("LuaFFIPerThreadAction failed inside Lua, returning ServFail: %s", e.what());
+ }
+ catch (...) {
+ warnlog("LuaFFIPerthreadAction failed inside Lua, returning ServFail: [unknown exception]");
+ }
+ return DNSAction::Action::ServFail;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return "Lua FFI per-thread script";
+ }
+
+private:
+ struct PerThreadState
+ {
+ LuaContext d_luaContext;
+ LuaActionFFIFunction d_func;
+ bool d_initialized{false};
+ };
+ static std::atomic<uint64_t> s_functionsCounter;
+ static thread_local std::map<uint64_t, PerThreadState> t_perThreadStates;
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
+ const std::string d_functionCode;
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
+ const uint64_t d_functionID;
+};
+
+std::atomic<uint64_t> LuaFFIPerThreadAction::s_functionsCounter = 0;
+thread_local std::map<uint64_t, LuaFFIPerThreadAction::PerThreadState> LuaFFIPerThreadAction::t_perThreadStates;
+
+class LuaFFIResponseAction : public DNSResponseAction
+{
+public:
+ LuaFFIResponseAction(LuaResponseActionFFIFunction func) :
+ d_func(std::move(func))
+ {
+ }
+
+ DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
+ {
+ dnsdist_ffi_dnsresponse_t ffiResponse(response);
+ try {
+ DNSResponseAction::Action result{};
+ {
+ auto lock = g_lua.lock();
+ auto ret = d_func(&ffiResponse);
+ if (ruleresult != nullptr) {
+ if (ffiResponse.result) {
+ *ruleresult = *ffiResponse.result;
+ }
+ else {
+ // default to empty string
+ ruleresult->clear();
+ }
+ }
+ result = static_cast<DNSResponseAction::Action>(ret);
+ }
+ dnsdist::handleQueuedAsynchronousEvents();
+ return result;
+ }
+ catch (const std::exception& e) {
+ warnlog("LuaFFIResponseAction failed inside Lua, returning ServFail: %s", e.what());
+ }
+ catch (...) {
+ warnlog("LuaFFIResponseAction failed inside Lua, returning ServFail: [unknown exception]");
+ }
+ return DNSResponseAction::Action::ServFail;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return "Lua FFI script";
+ }
+
+private:
+ LuaResponseActionFFIFunction d_func;
+};
+
+class LuaFFIPerThreadResponseAction : public DNSResponseAction
+{
+public:
+ LuaFFIPerThreadResponseAction(std::string code) :
+ d_functionCode(std::move(code)), d_functionID(s_functionsCounter++)
+ {
+ }
+
+ DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
+ {
+ try {
+ auto& state = t_perThreadStates[d_functionID];
+ if (!state.d_initialized) {
+ setupLuaFFIPerThreadContext(state.d_luaContext);
+ /* mark the state as initialized first so if there is a syntax error
+ we only try to execute the code once */
+ state.d_initialized = true;
+ state.d_func = state.d_luaContext.executeCode<LuaResponseActionFFIFunction>(d_functionCode);
+ }
+
+ if (!state.d_func) {
+ /* the function was not properly initialized */
+ return DNSResponseAction::Action::None;
+ }
+
+ dnsdist_ffi_dnsresponse_t ffiResponse(response);
+ auto ret = state.d_func(&ffiResponse);
+ if (ruleresult != nullptr) {
+ if (ffiResponse.result) {
+ *ruleresult = *ffiResponse.result;
+ }
+ else {
+ // default to empty string
+ ruleresult->clear();
+ }
+ }
+ dnsdist::handleQueuedAsynchronousEvents();
+ return static_cast<DNSResponseAction::Action>(ret);
+ }
+ catch (const std::exception& e) {
+ warnlog("LuaFFIPerThreadResponseAction failed inside Lua, returning ServFail: %s", e.what());
+ }
+ catch (...) {
+ warnlog("LuaFFIPerthreadResponseAction failed inside Lua, returning ServFail: [unknown exception]");
+ }
+ return DNSResponseAction::Action::ServFail;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return "Lua FFI per-thread script";
+ }
+
+private:
+ struct PerThreadState
+ {
+ LuaContext d_luaContext;
+ LuaResponseActionFFIFunction d_func;
+ bool d_initialized{false};
+ };
+
+ static std::atomic<uint64_t> s_functionsCounter;
+ static thread_local std::map<uint64_t, PerThreadState> t_perThreadStates;
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
+ const std::string d_functionCode;
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
+ const uint64_t d_functionID;
+};
+
+std::atomic<uint64_t> LuaFFIPerThreadResponseAction::s_functionsCounter = 0;
+thread_local std::map<uint64_t, LuaFFIPerThreadResponseAction::PerThreadState> LuaFFIPerThreadResponseAction::t_perThreadStates;
+
+class SpoofAction : public DNSAction
+{
+public:
+ SpoofAction(const vector<ComboAddress>& addrs, const dnsdist::ResponseConfig& responseConfig) :
+ d_responseConfig(responseConfig), d_addrs(addrs)
+ {
+ for (const auto& addr : d_addrs) {
+ if (addr.isIPv4()) {
+ d_types.insert(QType::A);
+ }
+ else if (addr.isIPv6()) {
+ d_types.insert(QType::AAAA);
+ }
+ }
+
+ if (!d_addrs.empty()) {
+ d_types.insert(QType::ANY);
+ }
+ }
+
+ SpoofAction(const DNSName& cname, const dnsdist::ResponseConfig& responseConfig) :
+ d_responseConfig(responseConfig), d_cname(cname)
+ {
+ }
+
+ SpoofAction(const PacketBuffer& rawresponse) :
+ d_raw(rawresponse)
+ {
+ }
+
+ SpoofAction(const vector<std::string>& raws, std::optional<uint16_t> typeForAny, const dnsdist::ResponseConfig& responseConfig) :
+ d_responseConfig(responseConfig), d_rawResponses(raws), d_rawTypeForAny(typeForAny)
+ {
+ }
+
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, string* ruleresult) const override;
+
+ string toString() const override
+ {
+ string ret = "spoof in ";
+ if (!d_cname.empty()) {
+ ret += d_cname.toString() + " ";
+ }
+ if (d_rawResponses.size() > 0) {
+ ret += "raw bytes ";
+ }
+ else {
+ for (const auto& a : d_addrs)
+ ret += a.toString() + " ";
+ }
+ return ret;
+ }
+
+private:
+ dnsdist::ResponseConfig d_responseConfig;
+ std::vector<ComboAddress> d_addrs;
+ std::unordered_set<uint16_t> d_types;
+ std::vector<std::string> d_rawResponses;
+ PacketBuffer d_raw;
+ DNSName d_cname;
+ std::optional<uint16_t> d_rawTypeForAny{};
+};
+
+DNSAction::Action SpoofAction::operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const
+{
+ uint16_t qtype = dnsquestion->ids.qtype;
+ // do we even have a response?
+ if (d_cname.empty() && d_rawResponses.empty() &&
+ // make sure pre-forged response is greater than sizeof(dnsheader)
+ (d_raw.size() < sizeof(dnsheader)) && d_types.count(qtype) == 0) {
+ return Action::None;
+ }
+
+ if (d_raw.size() >= sizeof(dnsheader)) {
+ dnsdist::self_answers::generateAnswerFromRawPacket(*dnsquestion, d_raw);
+ return Action::HeaderModify;
+ }
+
+ if (!d_cname.empty()) {
+ if (dnsdist::self_answers::generateAnswerFromCNAME(*dnsquestion, d_cname, d_responseConfig)) {
+ return Action::HeaderModify;
+ }
+ }
+ else if (!d_rawResponses.empty()) {
+ if (dnsdist::self_answers::generateAnswerFromRDataEntries(*dnsquestion, d_rawResponses, d_rawTypeForAny, d_responseConfig)) {
+ return Action::HeaderModify;
+ }
+ }
+ else {
+ if (dnsdist::self_answers::generateAnswerFromIPAddresses(*dnsquestion, d_addrs, d_responseConfig)) {
+ return Action::HeaderModify;
+ }
+ }
+
+ return Action::None;
+}
+
+class SetMacAddrAction : public DNSAction
+{
+public:
+ // this action does not stop the processing
+ SetMacAddrAction(uint16_t code) :
+ d_code(code)
+ {
+ }
+
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ dnsdist::MacAddress mac{};
+ int res = dnsdist::MacAddressesCache::get(dnsquestion->ids.origRemote, mac.data(), mac.size());
+ if (res != 0) {
+ return Action::None;
+ }
+
+ std::string optRData;
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+ generateEDNSOption(d_code, reinterpret_cast<const char*>(mac.data()), optRData);
+
+ if (dnsquestion->getHeader()->arcount > 0) {
+ bool ednsAdded = false;
+ bool optionAdded = false;
+ PacketBuffer newContent;
+ newContent.reserve(dnsquestion->getData().size());
+
+ if (!slowRewriteEDNSOptionInQueryWithRecords(dnsquestion->getData(), newContent, ednsAdded, d_code, optionAdded, true, optRData)) {
+ return Action::None;
+ }
+
+ if (newContent.size() > dnsquestion->getMaximumSize()) {
+ return Action::None;
+ }
+
+ dnsquestion->getMutableData() = std::move(newContent);
+ if (!dnsquestion->ids.ednsAdded && ednsAdded) {
+ dnsquestion->ids.ednsAdded = true;
+ }
+
+ return Action::None;
+ }
+
+ auto& data = dnsquestion->getMutableData();
+ if (generateOptRR(optRData, data, dnsquestion->getMaximumSize(), dnsdist::configuration::s_EdnsUDPPayloadSize, 0, false)) {
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [](dnsheader& header) {
+ header.arcount = htons(1);
+ return true;
+ });
+ // make sure that any EDNS sent by the backend is removed before forwarding the response to the client
+ dnsquestion->ids.ednsAdded = true;
+ }
+
+ return Action::None;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "add EDNS MAC (code=" + std::to_string(d_code) + ")";
+ }
+
+private:
+ uint16_t d_code{3};
+};
+
+class SetEDNSOptionAction : public DNSAction
+{
+public:
+ // this action does not stop the processing
+ SetEDNSOptionAction(uint16_t code, std::string data) :
+ d_code(code), d_data(std::move(data))
+ {
+ }
+
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ setEDNSOption(*dnsquestion, d_code, d_data);
+ return Action::None;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return "add EDNS Option (code=" + std::to_string(d_code) + ")";
+ }
+
+private:
+ uint16_t d_code;
+ std::string d_data;
+};
+
+class SetNoRecurseAction : public DNSAction
+{
+public:
+ // this action does not stop the processing
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [](dnsheader& header) {
+ header.rd = false;
+ return true;
+ });
+ return Action::None;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "set rd=0";
+ }
+};
+
+class LogAction : public DNSAction, public boost::noncopyable
+{
+public:
+ // this action does not stop the processing
+ LogAction() = default;
+
+ LogAction(const std::string& str, bool binary = true, bool append = false, bool buffered = true, bool verboseOnly = true, bool includeTimestamp = false) :
+ d_fname(str), d_binary(binary), d_verboseOnly(verboseOnly), d_includeTimestamp(includeTimestamp), d_append(append), d_buffered(buffered)
+ {
+ if (str.empty()) {
+ return;
+ }
+
+ if (!reopenLogFile()) {
+ throw std::runtime_error("Unable to open file '" + str + "' for logging: " + stringerror());
+ }
+ }
+
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ auto filepointer = std::atomic_load_explicit(&d_fp, std::memory_order_acquire);
+ if (!filepointer) {
+ if (!d_verboseOnly || dnsdist::configuration::getCurrentRuntimeConfiguration().d_verbose) {
+ if (d_includeTimestamp) {
+ infolog("[%u.%u] Packet from %s for %s %s with id %d", static_cast<unsigned long long>(dnsquestion->getQueryRealTime().tv_sec), static_cast<unsigned long>(dnsquestion->getQueryRealTime().tv_nsec), dnsquestion->ids.origRemote.toStringWithPort(), dnsquestion->ids.qname.toString(), QType(dnsquestion->ids.qtype).toString(), dnsquestion->getHeader()->id);
+ }
+ else {
+ infolog("Packet from %s for %s %s with id %d", dnsquestion->ids.origRemote.toStringWithPort(), dnsquestion->ids.qname.toString(), QType(dnsquestion->ids.qtype).toString(), dnsquestion->getHeader()->id);
+ }
+ }
+ }
+ else {
+ if (d_binary) {
+ const auto& out = dnsquestion->ids.qname.getStorage();
+ if (d_includeTimestamp) {
+ auto tv_sec = static_cast<uint64_t>(dnsquestion->getQueryRealTime().tv_sec);
+ auto tv_nsec = static_cast<uint32_t>(dnsquestion->getQueryRealTime().tv_nsec);
+ fwrite(&tv_sec, sizeof(tv_sec), 1, filepointer.get());
+ fwrite(&tv_nsec, sizeof(tv_nsec), 1, filepointer.get());
+ }
+ uint16_t queryId = dnsquestion->getHeader()->id;
+ fwrite(&queryId, sizeof(queryId), 1, filepointer.get());
+ fwrite(out.c_str(), 1, out.size(), filepointer.get());
+ fwrite(&dnsquestion->ids.qtype, sizeof(dnsquestion->ids.qtype), 1, filepointer.get());
+ fwrite(&dnsquestion->ids.origRemote.sin4.sin_family, sizeof(dnsquestion->ids.origRemote.sin4.sin_family), 1, filepointer.get());
+ if (dnsquestion->ids.origRemote.sin4.sin_family == AF_INET) {
+ fwrite(&dnsquestion->ids.origRemote.sin4.sin_addr.s_addr, sizeof(dnsquestion->ids.origRemote.sin4.sin_addr.s_addr), 1, filepointer.get());
+ }
+ else if (dnsquestion->ids.origRemote.sin4.sin_family == AF_INET6) {
+ fwrite(&dnsquestion->ids.origRemote.sin6.sin6_addr.s6_addr, sizeof(dnsquestion->ids.origRemote.sin6.sin6_addr.s6_addr), 1, filepointer.get());
+ }
+ fwrite(&dnsquestion->ids.origRemote.sin4.sin_port, sizeof(dnsquestion->ids.origRemote.sin4.sin_port), 1, filepointer.get());
+ }
+ else {
+ if (d_includeTimestamp) {
+ fprintf(filepointer.get(), "[%llu.%lu] Packet from %s for %s %s with id %u\n", static_cast<unsigned long long>(dnsquestion->getQueryRealTime().tv_sec), static_cast<unsigned long>(dnsquestion->getQueryRealTime().tv_nsec), dnsquestion->ids.origRemote.toStringWithPort().c_str(), dnsquestion->ids.qname.toString().c_str(), QType(dnsquestion->ids.qtype).toString().c_str(), dnsquestion->getHeader()->id);
+ }
+ else {
+ fprintf(filepointer.get(), "Packet from %s for %s %s with id %u\n", dnsquestion->ids.origRemote.toStringWithPort().c_str(), dnsquestion->ids.qname.toString().c_str(), QType(dnsquestion->ids.qtype).toString().c_str(), dnsquestion->getHeader()->id);
+ }
+ }
+ }
+ return Action::None;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ if (!d_fname.empty()) {
+ return "log to " + d_fname;
+ }
+ return "log";
+ }
+
+ void reload() override
+ {
+ if (!reopenLogFile()) {
+ warnlog("Unable to open file '%s' for logging: %s", d_fname, stringerror());
+ }
+ }
+
+private:
+ bool reopenLogFile()
+ {
+ // we are using a naked pointer here because we don't want fclose to be called
+ // with a nullptr, which would happen if we constructor a shared_ptr with fclose
+ // as a custom deleter and nullptr as a FILE*
+ // NOLINTNEXTLINE(cppcoreguidelines-owning-memory)
+ auto* nfp = fopen(d_fname.c_str(), d_append ? "a+" : "w");
+ if (nfp == nullptr) {
+ /* don't fall on our sword when reopening */
+ return false;
+ }
+
+ auto filepointer = std::shared_ptr<FILE>(nfp, fclose);
+ nfp = nullptr;
+
+ if (!d_buffered) {
+ setbuf(filepointer.get(), nullptr);
+ }
+
+ std::atomic_store_explicit(&d_fp, std::move(filepointer), std::memory_order_release);
+ return true;
+ }
+
+ std::string d_fname;
+ std::shared_ptr<FILE> d_fp{nullptr};
+ bool d_binary{true};
+ bool d_verboseOnly{true};
+ bool d_includeTimestamp{false};
+ bool d_append{false};
+ bool d_buffered{true};
+};
+
+class LogResponseAction : public DNSResponseAction, public boost::noncopyable
+{
+public:
+ LogResponseAction() = default;
+
+ LogResponseAction(const std::string& str, bool append = false, bool buffered = true, bool verboseOnly = true, bool includeTimestamp = false) :
+ d_fname(str), d_verboseOnly(verboseOnly), d_includeTimestamp(includeTimestamp), d_append(append), d_buffered(buffered)
+ {
+ if (str.empty()) {
+ return;
+ }
+
+ if (!reopenLogFile()) {
+ throw std::runtime_error("Unable to open file '" + str + "' for logging: " + stringerror());
+ }
+ }
+
+ DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
+ {
+ auto filepointer = std::atomic_load_explicit(&d_fp, std::memory_order_acquire);
+ if (!filepointer) {
+ if (!d_verboseOnly || dnsdist::configuration::getCurrentRuntimeConfiguration().d_verbose) {
+ if (d_includeTimestamp) {
+ infolog("[%u.%u] Answer to %s for %s %s (%s) with id %u", static_cast<unsigned long long>(response->getQueryRealTime().tv_sec), static_cast<unsigned long>(response->getQueryRealTime().tv_nsec), response->ids.origRemote.toStringWithPort(), response->ids.qname.toString(), QType(response->ids.qtype).toString(), RCode::to_s(response->getHeader()->rcode), response->getHeader()->id);
+ }
+ else {
+ infolog("Answer to %s for %s %s (%s) with id %u", response->ids.origRemote.toStringWithPort(), response->ids.qname.toString(), QType(response->ids.qtype).toString(), RCode::to_s(response->getHeader()->rcode), response->getHeader()->id);
+ }
+ }
+ }
+ else {
+ if (d_includeTimestamp) {
+ fprintf(filepointer.get(), "[%llu.%lu] Answer to %s for %s %s (%s) with id %u\n", static_cast<unsigned long long>(response->getQueryRealTime().tv_sec), static_cast<unsigned long>(response->getQueryRealTime().tv_nsec), response->ids.origRemote.toStringWithPort().c_str(), response->ids.qname.toString().c_str(), QType(response->ids.qtype).toString().c_str(), RCode::to_s(response->getHeader()->rcode).c_str(), response->getHeader()->id);
+ }
+ else {
+ fprintf(filepointer.get(), "Answer to %s for %s %s (%s) with id %u\n", response->ids.origRemote.toStringWithPort().c_str(), response->ids.qname.toString().c_str(), QType(response->ids.qtype).toString().c_str(), RCode::to_s(response->getHeader()->rcode).c_str(), response->getHeader()->id);
+ }
+ }
+ return Action::None;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ if (!d_fname.empty()) {
+ return "log to " + d_fname;
+ }
+ return "log";
+ }
+
+ void reload() override
+ {
+ if (!reopenLogFile()) {
+ warnlog("Unable to open file '%s' for logging: %s", d_fname, stringerror());
+ }
+ }
+
+private:
+ bool reopenLogFile()
+ {
+ // we are using a naked pointer here because we don't want fclose to be called
+ // with a nullptr, which would happen if we constructor a shared_ptr with fclose
+ // as a custom deleter and nullptr as a FILE*
+ // NOLINTNEXTLINE(cppcoreguidelines-owning-memory)
+ auto* nfp = fopen(d_fname.c_str(), d_append ? "a+" : "w");
+ if (nfp == nullptr) {
+ /* don't fall on our sword when reopening */
+ return false;
+ }
+
+ auto filepointer = std::shared_ptr<FILE>(nfp, fclose);
+ nfp = nullptr;
+
+ if (!d_buffered) {
+ setbuf(filepointer.get(), nullptr);
+ }
+
+ std::atomic_store_explicit(&d_fp, std::move(filepointer), std::memory_order_release);
+ return true;
+ }
+
+ std::string d_fname;
+ std::shared_ptr<FILE> d_fp{nullptr};
+ bool d_verboseOnly{true};
+ bool d_includeTimestamp{false};
+ bool d_append{false};
+ bool d_buffered{true};
+};
+
+class SetDisableValidationAction : public DNSAction
+{
+public:
+ // this action does not stop the processing
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [](dnsheader& header) {
+ header.cd = true;
+ return true;
+ });
+ return Action::None;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "set cd=1";
+ }
+};
+
+class SetSkipCacheAction : public DNSAction
+{
+public:
+ // this action does not stop the processing
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ dnsquestion->ids.skipCache = true;
+ return Action::None;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "skip cache";
+ }
+};
+
+class SetSkipCacheResponseAction : public DNSResponseAction
+{
+public:
+ DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
+ {
+ response->ids.skipCache = true;
+ return Action::None;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "skip cache";
+ }
+};
+
+class SetTempFailureCacheTTLAction : public DNSAction
+{
+public:
+ // this action does not stop the processing
+ SetTempFailureCacheTTLAction(uint32_t ttl) :
+ d_ttl(ttl)
+ {
+ }
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ dnsquestion->ids.tempFailureTTL = d_ttl;
+ return Action::None;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "set tempfailure cache ttl to " + std::to_string(d_ttl);
+ }
+
+private:
+ uint32_t d_ttl;
+};
+
+class SetECSPrefixLengthAction : public DNSAction
+{
+public:
+ // this action does not stop the processing
+ SetECSPrefixLengthAction(uint16_t v4Length, uint16_t v6Length) :
+ d_v4PrefixLength(v4Length), d_v6PrefixLength(v6Length)
+ {
+ }
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ dnsquestion->ecsPrefixLength = dnsquestion->ids.origRemote.sin4.sin_family == AF_INET ? d_v4PrefixLength : d_v6PrefixLength;
+ return Action::None;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "set ECS prefix length to " + std::to_string(d_v4PrefixLength) + "/" + std::to_string(d_v6PrefixLength);
+ }
+
+private:
+ uint16_t d_v4PrefixLength;
+ uint16_t d_v6PrefixLength;
+};
+
+class SetECSOverrideAction : public DNSAction
+{
+public:
+ // this action does not stop the processing
+ SetECSOverrideAction(bool ecsOverride) :
+ d_ecsOverride(ecsOverride)
+ {
+ }
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ dnsquestion->ecsOverride = d_ecsOverride;
+ return Action::None;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "set ECS override to " + std::to_string(static_cast<int>(d_ecsOverride));
+ }
+
+private:
+ bool d_ecsOverride;
+};
+
+class SetDisableECSAction : public DNSAction
+{
+public:
+ // this action does not stop the processing
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ dnsquestion->useECS = false;
+ return Action::None;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "disable ECS";
+ }
+};
+
+class SetECSAction : public DNSAction
+{
+public:
+ // this action does not stop the processing
+ SetECSAction(const Netmask& v4Netmask) :
+ d_v4(v4Netmask), d_hasV6(false)
+ {
+ }
+
+ SetECSAction(const Netmask& v4Netmask, const Netmask& v6Netmask) :
+ d_v4(v4Netmask), d_v6(v6Netmask), d_hasV6(true)
+ {
+ }
+
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ if (d_hasV6) {
+ dnsquestion->ecs = std::make_unique<Netmask>(dnsquestion->ids.origRemote.isIPv4() ? d_v4 : d_v6);
+ }
+ else {
+ dnsquestion->ecs = std::make_unique<Netmask>(d_v4);
+ }
+
+ return Action::None;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ std::string result = "set ECS to " + d_v4.toString();
+ if (d_hasV6) {
+ result += " / " + d_v6.toString();
+ }
+ return result;
+ }
+
+private:
+ Netmask d_v4;
+ Netmask d_v6;
+ bool d_hasV6;
+};
+
+#ifndef DISABLE_PROTOBUF
+static DnstapMessage::ProtocolType ProtocolToDNSTap(dnsdist::Protocol protocol)
+{
+ if (protocol == dnsdist::Protocol::DoUDP) {
+ return DnstapMessage::ProtocolType::DoUDP;
+ }
+ if (protocol == dnsdist::Protocol::DoTCP) {
+ return DnstapMessage::ProtocolType::DoTCP;
+ }
+ if (protocol == dnsdist::Protocol::DoT) {
+ return DnstapMessage::ProtocolType::DoT;
+ }
+ if (protocol == dnsdist::Protocol::DoH || protocol == dnsdist::Protocol::DoH3) {
+ return DnstapMessage::ProtocolType::DoH;
+ }
+ if (protocol == dnsdist::Protocol::DNSCryptUDP) {
+ return DnstapMessage::ProtocolType::DNSCryptUDP;
+ }
+ if (protocol == dnsdist::Protocol::DNSCryptTCP) {
+ return DnstapMessage::ProtocolType::DNSCryptTCP;
+ }
+ if (protocol == dnsdist::Protocol::DoQ) {
+ return DnstapMessage::ProtocolType::DoQ;
+ }
+ throw std::runtime_error("Unhandled protocol for dnstap: " + protocol.toPrettyString());
+}
+
+static void remoteLoggerQueueData(RemoteLoggerInterface& remoteLogger, const std::string& data)
+{
+ auto ret = remoteLogger.queueData(data);
+
+ switch (ret) {
+ case RemoteLoggerInterface::Result::Queued:
+ break;
+ case RemoteLoggerInterface::Result::PipeFull: {
+ vinfolog("%s: %s", remoteLogger.name(), RemoteLoggerInterface::toErrorString(ret));
+ break;
+ }
+ case RemoteLoggerInterface::Result::TooLarge: {
+ warnlog("%s: %s", remoteLogger.name(), RemoteLoggerInterface::toErrorString(ret));
+ break;
+ }
+ case RemoteLoggerInterface::Result::OtherError:
+ warnlog("%s: %s", remoteLogger.name(), RemoteLoggerInterface::toErrorString(ret));
+ }
+}
+
+class DnstapLogAction : public DNSAction, public boost::noncopyable
+{
+public:
+ // this action does not stop the processing
+ DnstapLogAction(std::string identity, std::shared_ptr<RemoteLoggerInterface>& logger, std::optional<std::function<void(DNSQuestion*, DnstapMessage*)>> alterFunc) :
+ d_identity(std::move(identity)), d_logger(logger), d_alterFunc(std::move(alterFunc))
+ {
+ }
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ static thread_local std::string data;
+ data.clear();
+
+ DnstapMessage::ProtocolType protocol = ProtocolToDNSTap(dnsquestion->getProtocol());
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+ DnstapMessage message(std::move(data), !dnsquestion->getHeader()->qr ? DnstapMessage::MessageType::client_query : DnstapMessage::MessageType::client_response, d_identity, &dnsquestion->ids.origRemote, &dnsquestion->ids.origDest, protocol, reinterpret_cast<const char*>(dnsquestion->getData().data()), dnsquestion->getData().size(), &dnsquestion->getQueryRealTime(), nullptr);
+ {
+ if (d_alterFunc) {
+ auto lock = g_lua.lock();
+ (*d_alterFunc)(dnsquestion, &message);
+ }
+ }
+
+ data = message.getBuffer();
+ remoteLoggerQueueData(*d_logger, data);
+
+ return Action::None;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "remote log as dnstap to " + (d_logger ? d_logger->toString() : "");
+ }
+
+private:
+ std::string d_identity;
+ std::shared_ptr<RemoteLoggerInterface> d_logger;
+ std::optional<std::function<void(DNSQuestion*, DnstapMessage*)>> d_alterFunc;
+};
+
+namespace
+{
+ void addMetaDataToProtobuf(DNSDistProtoBufMessage& message, const DNSQuestion& dnsquestion, const std::vector<std::pair<std::string, ProtoBufMetaKey>>& metas)
+ {
+ for (const auto& [name, meta] : metas) {
+ message.addMeta(name, meta.getValues(dnsquestion), {});
+ }
+ }
+
+ void addTagsToProtobuf(DNSDistProtoBufMessage& message, const DNSQuestion& dnsquestion, const std::unordered_set<std::string>& allowed)
+ {
+ if (!dnsquestion.ids.qTag) {
+ return;
+ }
+
+ for (const auto& [key, value] : *dnsquestion.ids.qTag) {
+ if (!allowed.empty() && allowed.count(key) == 0) {
+ continue;
+ }
+
+ if (value.empty()) {
+ message.addTag(key);
+ }
+ else {
+ auto tag = key;
+ tag.append(":");
+ tag.append(value);
+ message.addTag(tag);
+ }
+ }
+ }
+
+ void addExtendedDNSErrorToProtobuf(DNSDistProtoBufMessage& message, const DNSResponse& response, const std::string& metaKey)
+ {
+ auto [infoCode, extraText] = dnsdist::edns::getExtendedDNSError(response.getData());
+ if (!infoCode) {
+ return;
+ }
+
+ if (extraText) {
+ message.addMeta(metaKey, {*extraText}, {*infoCode});
+ }
+ else {
+ message.addMeta(metaKey, {}, {*infoCode});
+ }
+ }
+}
+
+class RemoteLogAction : public DNSAction, public boost::noncopyable
+{
+public:
+ // this action does not stop the processing
+ RemoteLogAction(RemoteLogActionConfiguration& config) :
+ d_tagsToExport(std::move(config.tagsToExport)), d_metas(std::move(config.metas)), d_logger(config.logger), d_alterFunc(std::move(config.alterQueryFunc)), d_serverID(config.serverID), d_ipEncryptKey(config.ipEncryptKey)
+ {
+ }
+
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ if (!dnsquestion->ids.d_protoBufData) {
+ dnsquestion->ids.d_protoBufData = std::make_unique<InternalQueryState::ProtoBufData>();
+ }
+ if (!dnsquestion->ids.d_protoBufData->uniqueId) {
+ dnsquestion->ids.d_protoBufData->uniqueId = getUniqueID();
+ }
+
+ DNSDistProtoBufMessage message(*dnsquestion);
+ if (!d_serverID.empty()) {
+ message.setServerIdentity(d_serverID);
+ }
+
+#ifdef HAVE_IPCIPHER
+ if (!d_ipEncryptKey.empty()) {
+ message.setRequestor(encryptCA(dnsquestion->ids.origRemote, d_ipEncryptKey));
+ }
+#endif /* HAVE_IPCIPHER */
+
+ if (d_tagsToExport) {
+ addTagsToProtobuf(message, *dnsquestion, *d_tagsToExport);
+ }
+
+ addMetaDataToProtobuf(message, *dnsquestion, d_metas);
+
+ if (d_alterFunc) {
+ auto lock = g_lua.lock();
+ (*d_alterFunc)(dnsquestion, &message);
+ }
+
+ static thread_local std::string data;
+ data.clear();
+ message.serialize(data);
+ remoteLoggerQueueData(*d_logger, data);
+
+ return Action::None;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "remote log to " + (d_logger ? d_logger->toString() : "");
+ }
+
+private:
+ std::optional<std::unordered_set<std::string>> d_tagsToExport;
+ std::vector<std::pair<std::string, ProtoBufMetaKey>> d_metas;
+ std::shared_ptr<RemoteLoggerInterface> d_logger;
+ std::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)>> d_alterFunc;
+ std::string d_serverID;
+ std::string d_ipEncryptKey;
+};
+
+#endif /* DISABLE_PROTOBUF */
+
+class SNMPTrapAction : public DNSAction
+{
+public:
+ // this action does not stop the processing
+ SNMPTrapAction(std::string reason) :
+ d_reason(std::move(reason))
+ {
+ }
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ if (g_snmpAgent != nullptr && dnsdist::configuration::getImmutableConfiguration().d_snmpTrapsEnabled) {
+ g_snmpAgent->sendDNSTrap(*dnsquestion, d_reason);
+ }
+
+ return Action::None;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "send SNMP trap";
+ }
+
+private:
+ std::string d_reason;
+};
+
+class SetTagAction : public DNSAction
+{
+public:
+ // this action does not stop the processing
+ SetTagAction(std::string tag, std::string value) :
+ d_tag(std::move(tag)), d_value(std::move(value))
+ {
+ }
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ dnsquestion->setTag(d_tag, d_value);
+
+ return Action::None;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "set tag '" + d_tag + "' to value '" + d_value + "'";
+ }
+
+private:
+ std::string d_tag;
+ std::string d_value;
+};
+
+#ifndef DISABLE_PROTOBUF
+class DnstapLogResponseAction : public DNSResponseAction, public boost::noncopyable
+{
+public:
+ // this action does not stop the processing
+ DnstapLogResponseAction(std::string identity, std::shared_ptr<RemoteLoggerInterface>& logger, std::optional<std::function<void(DNSResponse*, DnstapMessage*)>> alterFunc) :
+ d_identity(std::move(identity)), d_logger(logger), d_alterFunc(std::move(alterFunc))
+ {
+ }
+ DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
+ {
+ static thread_local std::string data;
+ struct timespec now = {};
+ gettime(&now, true);
+ data.clear();
+
+ DnstapMessage::ProtocolType protocol = ProtocolToDNSTap(response->getProtocol());
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+ DnstapMessage message(std::move(data), DnstapMessage::MessageType::client_response, d_identity, &response->ids.origRemote, &response->ids.origDest, protocol, reinterpret_cast<const char*>(response->getData().data()), response->getData().size(), &response->getQueryRealTime(), &now);
+ {
+ if (d_alterFunc) {
+ auto lock = g_lua.lock();
+ (*d_alterFunc)(response, &message);
+ }
+ }
+
+ data = message.getBuffer();
+ remoteLoggerQueueData(*d_logger, data);
+
+ return Action::None;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "log response as dnstap to " + (d_logger ? d_logger->toString() : "");
+ }
+
+private:
+ std::string d_identity;
+ std::shared_ptr<RemoteLoggerInterface> d_logger;
+ std::optional<std::function<void(DNSResponse*, DnstapMessage*)>> d_alterFunc;
+};
+
+class RemoteLogResponseAction : public DNSResponseAction, public boost::noncopyable
+{
+public:
+ // this action does not stop the processing
+ RemoteLogResponseAction(RemoteLogActionConfiguration& config) :
+ d_tagsToExport(std::move(config.tagsToExport)), d_metas(std::move(config.metas)), d_logger(config.logger), d_alterFunc(std::move(config.alterResponseFunc)), d_serverID(config.serverID), d_ipEncryptKey(config.ipEncryptKey), d_exportExtendedErrorsToMeta(std::move(config.exportExtendedErrorsToMeta)), d_includeCNAME(config.includeCNAME)
+ {
+ }
+ DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
+ {
+ if (!response->ids.d_protoBufData) {
+ response->ids.d_protoBufData = std::make_unique<InternalQueryState::ProtoBufData>();
+ }
+ if (!response->ids.d_protoBufData->uniqueId) {
+ response->ids.d_protoBufData->uniqueId = getUniqueID();
+ }
+
+ DNSDistProtoBufMessage message(*response, d_includeCNAME);
+ if (!d_serverID.empty()) {
+ message.setServerIdentity(d_serverID);
+ }
+
+#ifdef HAVE_IPCIPHER
+ if (!d_ipEncryptKey.empty()) {
+ message.setRequestor(encryptCA(response->ids.origRemote, d_ipEncryptKey));
+ }
+#endif /* HAVE_IPCIPHER */
+
+ if (d_tagsToExport) {
+ addTagsToProtobuf(message, *response, *d_tagsToExport);
+ }
+
+ addMetaDataToProtobuf(message, *response, d_metas);
+
+ if (d_exportExtendedErrorsToMeta) {
+ addExtendedDNSErrorToProtobuf(message, *response, *d_exportExtendedErrorsToMeta);
+ }
+
+ if (d_alterFunc) {
+ auto lock = g_lua.lock();
+ (*d_alterFunc)(response, &message);
+ }
+
+ static thread_local std::string data;
+ data.clear();
+ message.serialize(data);
+ d_logger->queueData(data);
+
+ return Action::None;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "remote log response to " + (d_logger ? d_logger->toString() : "");
+ }
+
+private:
+ std::optional<std::unordered_set<std::string>> d_tagsToExport;
+ std::vector<std::pair<std::string, ProtoBufMetaKey>> d_metas;
+ std::shared_ptr<RemoteLoggerInterface> d_logger;
+ std::optional<std::function<void(DNSResponse*, DNSDistProtoBufMessage*)>> d_alterFunc;
+ std::string d_serverID;
+ std::string d_ipEncryptKey;
+ std::optional<std::string> d_exportExtendedErrorsToMeta{std::nullopt};
+ bool d_includeCNAME;
+};
+
+#endif /* DISABLE_PROTOBUF */
+
+class DropResponseAction : public DNSResponseAction
+{
+public:
+ DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
+ {
+ return Action::Drop;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "drop";
+ }
+};
+
+class AllowResponseAction : public DNSResponseAction
+{
+public:
+ DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
+ {
+ return Action::Allow;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "allow";
+ }
+};
+
+class DelayResponseAction : public DNSResponseAction
+{
+public:
+ DelayResponseAction(int msec) :
+ d_msec(msec)
+ {
+ }
+ DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
+ {
+ *ruleresult = std::to_string(d_msec);
+ return Action::Delay;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "delay by " + std::to_string(d_msec) + " ms";
+ }
+
+private:
+ int d_msec;
+};
+
+class SNMPTrapResponseAction : public DNSResponseAction
+{
+public:
+ // this action does not stop the processing
+ SNMPTrapResponseAction(std::string reason) :
+ d_reason(std::move(reason))
+ {
+ }
+ DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
+ {
+ if (g_snmpAgent != nullptr && dnsdist::configuration::getImmutableConfiguration().d_snmpTrapsEnabled) {
+ g_snmpAgent->sendDNSTrap(*response, d_reason);
+ }
+
+ return Action::None;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "send SNMP trap";
+ }
+
+private:
+ std::string d_reason;
+};
+
+class SetTagResponseAction : public DNSResponseAction
+{
+public:
+ // this action does not stop the processing
+ SetTagResponseAction(std::string tag, std::string value) :
+ d_tag(std::move(tag)), d_value(std::move(value))
+ {
+ }
+ DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
+ {
+ response->setTag(d_tag, d_value);
+
+ return Action::None;
+ }
+ [[nodiscard]] std::string toString() const override
+ {
+ return "set tag '" + d_tag + "' to value '" + d_value + "'";
+ }
+
+private:
+ std::string d_tag;
+ std::string d_value;
+};
+
+class ClearRecordTypesResponseAction : public DNSResponseAction, public boost::noncopyable
+{
+public:
+ ClearRecordTypesResponseAction(std::unordered_set<QType> qtypes) :
+ d_qtypes(std::move(qtypes))
+ {
+ }
+
+ DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
+ {
+ if (!d_qtypes.empty()) {
+ clearDNSPacketRecordTypes(response->getMutableData(), d_qtypes);
+ }
+ return DNSResponseAction::Action::None;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return "clear record types";
+ }
+
+private:
+ std::unordered_set<QType> d_qtypes{};
+};
+
+class ContinueAction : public DNSAction
+{
+public:
+ // this action does not stop the processing
+ ContinueAction(std::shared_ptr<DNSAction>& action) :
+ d_action(action)
+ {
+ }
+
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ if (d_action) {
+ /* call the action */
+ auto action = (*d_action)(dnsquestion, ruleresult);
+ bool drop = false;
+ /* apply the changes if needed (pool selection, flags, etc */
+ processRulesResult(action, *dnsquestion, *ruleresult, drop);
+ }
+
+ /* but ignore the resulting action no matter what */
+ return Action::None;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ if (d_action) {
+ return "continue after: " + (d_action ? d_action->toString() : "");
+ }
+ return "no op";
+ }
+
+private:
+ std::shared_ptr<DNSAction> d_action;
+};
+
+#if defined(HAVE_DNS_OVER_HTTPS) || defined(HAVE_DNS_OVER_HTTP3)
+class HTTPStatusAction : public DNSAction
+{
+public:
+ HTTPStatusAction(int code, PacketBuffer body, std::string contentType, dnsdist::ResponseConfig responseConfig) :
+ d_responseConfig(responseConfig), d_body(std::move(body)), d_contentType(std::move(contentType)), d_code(code)
+ {
+ }
+
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+#if defined(HAVE_DNS_OVER_HTTPS)
+ if (dnsquestion->ids.du) {
+ dnsquestion->ids.du->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) {
+ header.qr = true; // for good measure
+ setResponseHeadersFromConfig(header, d_responseConfig);
+ return true;
+ });
+ return Action::HeaderModify;
+ }
+#endif /* HAVE_DNS_OVER_HTTPS */
+#if defined(HAVE_DNS_OVER_HTTP3)
+ if (dnsquestion->ids.doh3u) {
+ dnsquestion->ids.doh3u->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) {
+ header.qr = true; // for good measure
+ setResponseHeadersFromConfig(header, d_responseConfig);
+ return true;
+ });
+ return Action::HeaderModify;
+ }
+#endif /* HAVE_DNS_OVER_HTTP3 */
+ return Action::None;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return "return an HTTP status of " + std::to_string(d_code);
+ }
+
+private:
+ dnsdist::ResponseConfig d_responseConfig;
+ PacketBuffer d_body;
+ std::string d_contentType;
+ int d_code;
+};
+#endif /* HAVE_DNS_OVER_HTTPS || HAVE_DNS_OVER_HTTP3 */
+
+#if defined(HAVE_LMDB) || defined(HAVE_CDB)
+class KeyValueStoreLookupAction : public DNSAction
+{
+public:
+ // this action does not stop the processing
+ KeyValueStoreLookupAction(std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey, std::string destinationTag) :
+ d_kvs(kvs), d_key(lookupKey), d_tag(std::move(destinationTag))
+ {
+ }
+
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ std::vector<std::string> keys = d_key->getKeys(*dnsquestion);
+ std::string result;
+ for (const auto& key : keys) {
+ if (d_kvs->getValue(key, result)) {
+ break;
+ }
+ }
+
+ dnsquestion->setTag(d_tag, std::move(result));
+
+ return Action::None;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return "lookup key-value store based on '" + d_key->toString() + "' and set the result in tag '" + d_tag + "'";
+ }
+
+private:
+ std::shared_ptr<KeyValueStore> d_kvs;
+ std::shared_ptr<KeyValueLookupKey> d_key;
+ std::string d_tag;
+};
+
+class KeyValueStoreRangeLookupAction : public DNSAction
+{
+public:
+ // this action does not stop the processing
+ KeyValueStoreRangeLookupAction(std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey, std::string destinationTag) :
+ d_kvs(kvs), d_key(lookupKey), d_tag(std::move(destinationTag))
+ {
+ }
+
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ std::vector<std::string> keys = d_key->getKeys(*dnsquestion);
+ std::string result;
+ for (const auto& key : keys) {
+ if (d_kvs->getRangeValue(key, result)) {
+ break;
+ }
+ }
+
+ dnsquestion->setTag(d_tag, std::move(result));
+
+ return Action::None;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return "do a range-based lookup in key-value store based on '" + d_key->toString() + "' and set the result in tag '" + d_tag + "'";
+ }
+
+private:
+ std::shared_ptr<KeyValueStore> d_kvs;
+ std::shared_ptr<KeyValueLookupKey> d_key;
+ std::string d_tag;
+};
+#endif /* defined(HAVE_LMDB) || defined(HAVE_CDB) */
+
+class SetMaxReturnedTTLAction : public DNSAction
+{
+public:
+ SetMaxReturnedTTLAction(uint32_t cap) :
+ d_cap(cap)
+ {
+ }
+
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ dnsquestion->ids.ttlCap = d_cap;
+ return DNSAction::Action::None;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return "cap the TTL of the returned response to " + std::to_string(d_cap);
+ }
+
+private:
+ uint32_t d_cap;
+};
+
+class SetMaxReturnedTTLResponseAction : public DNSResponseAction
+{
+public:
+ SetMaxReturnedTTLResponseAction(uint32_t cap) :
+ d_cap(cap)
+ {
+ }
+
+ DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
+ {
+ response->ids.ttlCap = d_cap;
+ return DNSResponseAction::Action::None;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return "cap the TTL of the returned response to " + std::to_string(d_cap);
+ }
+
+private:
+ uint32_t d_cap;
+};
+
+class NegativeAndSOAAction : public DNSAction
+{
+public:
+ NegativeAndSOAAction(bool nxd, DNSName zone, uint32_t ttl, DNSName mname, DNSName rname, dnsdist::actions::SOAParams params, bool soaInAuthoritySection, dnsdist::ResponseConfig responseConfig) :
+ d_responseConfig(responseConfig), d_zone(std::move(zone)), d_mname(std::move(mname)), d_rname(std::move(rname)), d_ttl(ttl), d_params(params), d_nxd(nxd), d_soaInAuthoritySection(soaInAuthoritySection)
+ {
+ }
+
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ if (!setNegativeAndAdditionalSOA(*dnsquestion, d_nxd, d_zone, d_ttl, d_mname, d_rname, d_params.serial, d_params.refresh, d_params.retry, d_params.expire, d_params.minimum, d_soaInAuthoritySection)) {
+ return Action::None;
+ }
+
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) {
+ setResponseHeadersFromConfig(header, d_responseConfig);
+ return true;
+ });
+
+ return Action::Allow;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return std::string(d_nxd ? "NXD" : "NODATA") + " with SOA";
+ }
+
+private:
+ dnsdist::ResponseConfig d_responseConfig;
+ DNSName d_zone;
+ DNSName d_mname;
+ DNSName d_rname;
+ uint32_t d_ttl;
+ dnsdist::actions::SOAParams d_params;
+ bool d_nxd;
+ bool d_soaInAuthoritySection;
+};
+
+class SetProxyProtocolValuesAction : public DNSAction
+{
+public:
+ // this action does not stop the processing
+ SetProxyProtocolValuesAction(const std::vector<std::pair<uint8_t, std::string>>& values)
+ {
+ d_values.reserve(values.size());
+ for (const auto& value : values) {
+ d_values.push_back({value.second, value.first});
+ }
+ }
+
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ if (!dnsquestion->proxyProtocolValues) {
+ dnsquestion->proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>();
+ }
+
+ *(dnsquestion->proxyProtocolValues) = d_values;
+
+ return Action::None;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return "set Proxy-Protocol values";
+ }
+
+private:
+ std::vector<ProxyProtocolValue> d_values;
+};
+
+class SetAdditionalProxyProtocolValueAction : public DNSAction
+{
+public:
+ // this action does not stop the processing
+ SetAdditionalProxyProtocolValueAction(uint8_t type, std::string value) :
+ d_value(std::move(value)), d_type(type)
+ {
+ }
+
+ DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
+ {
+ if (!dnsquestion->proxyProtocolValues) {
+ dnsquestion->proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>();
+ }
+
+ dnsquestion->proxyProtocolValues->push_back({d_value, d_type});
+
+ return Action::None;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return "add a Proxy-Protocol value of type " + std::to_string(d_type);
+ }
+
+private:
+ std::string d_value;
+ uint8_t d_type;
+};
+
+class SetReducedTTLResponseAction : public DNSResponseAction, public boost::noncopyable
+{
+public:
+ // this action does not stop the processing
+ SetReducedTTLResponseAction(uint8_t percentage) :
+ d_ratio(percentage / 100.0)
+ {
+ }
+
+ DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
+ {
+ // NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
+ auto visitor = [&](uint8_t section, uint16_t qclass, uint16_t qtype, uint32_t ttl) {
+ return ttl * d_ratio;
+ };
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+ editDNSPacketTTL(reinterpret_cast<char*>(response->getMutableData().data()), response->getData().size(), visitor);
+ return DNSResponseAction::Action::None;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return "reduce ttl to " + std::to_string(d_ratio * 100) + " percent of its value";
+ }
+
+private:
+ double d_ratio{1.0};
+};
+
+class SetExtendedDNSErrorAction : public DNSAction
+{
+public:
+ // this action does not stop the processing
+ SetExtendedDNSErrorAction(uint16_t infoCode, const std::string& extraText)
+ {
+ d_ede.infoCode = infoCode;
+ d_ede.extraText = extraText;
+ }
+
+ DNSAction::Action operator()(DNSQuestion* dnsQuestion, std::string* ruleresult) const override
+ {
+ dnsQuestion->ids.d_extendedError = std::make_unique<EDNSExtendedError>(d_ede);
+
+ return DNSAction::Action::None;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return "set EDNS Extended DNS Error to " + std::to_string(d_ede.infoCode) + (d_ede.extraText.empty() ? std::string() : std::string(": \"") + d_ede.extraText + std::string("\""));
+ }
+
+private:
+ EDNSExtendedError d_ede;
+};
+
+class SetExtendedDNSErrorResponseAction : public DNSResponseAction
+{
+public:
+ // this action does not stop the processing
+ SetExtendedDNSErrorResponseAction(uint16_t infoCode, const std::string& extraText)
+ {
+ d_ede.infoCode = infoCode;
+ d_ede.extraText = extraText;
+ }
+
+ DNSResponseAction::Action operator()(DNSResponse* dnsResponse, std::string* ruleresult) const override
+ {
+ dnsResponse->ids.d_extendedError = std::make_unique<EDNSExtendedError>(d_ede);
+
+ return DNSResponseAction::Action::None;
+ }
+
+ [[nodiscard]] std::string toString() const override
+ {
+ return "set EDNS Extended DNS Error to " + std::to_string(d_ede.infoCode) + (d_ede.extraText.empty() ? std::string() : std::string(": \"") + d_ede.extraText + std::string("\""));
+ }
+
+private:
+ EDNSExtendedError d_ede;
+};
+
+class LimitTTLResponseAction : public DNSResponseAction, public boost::noncopyable
+{
+public:
+ LimitTTLResponseAction() {}
+
+ LimitTTLResponseAction(uint32_t min, uint32_t max = std::numeric_limits<uint32_t>::max(), const std::unordered_set<QType>& types = {}) :
+ d_types(types), d_min(min), d_max(max)
+ {
+ }
+
+ DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override
+ {
+ dnsdist::PacketMangling::restrictDNSPacketTTLs(dr->getMutableData(), d_min, d_max, d_types);
+ return DNSResponseAction::Action::None;
+ }
+
+ std::string toString() const override
+ {
+ std::string result = "limit ttl (" + std::to_string(d_min) + " <= ttl <= " + std::to_string(d_max);
+ if (!d_types.empty()) {
+ bool first = true;
+ result += ", types in [";
+ for (const auto& type : d_types) {
+ if (first) {
+ first = false;
+ }
+ else {
+ result += " ";
+ }
+ result += type.toString();
+ }
+ result += "]";
+ }
+ result += +")";
+ return result;
+ }
+
+private:
+ std::unordered_set<QType> d_types;
+ uint32_t d_min{0};
+ uint32_t d_max{std::numeric_limits<uint32_t>::max()};
+};
+
+#ifndef DISABLE_PROTOBUF
+std::shared_ptr<DNSAction> getRemoteLogAction(RemoteLogActionConfiguration& config)
+{
+ return std::shared_ptr<DNSAction>(new RemoteLogAction(config));
+}
+
+std::shared_ptr<DNSResponseAction> getRemoteLogResponseAction(RemoteLogActionConfiguration& config)
+{
+ return std::shared_ptr<DNSResponseAction>(new RemoteLogResponseAction(config));
+}
+
+std::shared_ptr<DNSAction> getDnstapLogAction(const std::string& identity, std::shared_ptr<RemoteLoggerInterface> logger, std::optional<DnstapAlterFunction> alterFunc)
+{
+ return std::shared_ptr<DNSAction>(new DnstapLogAction(identity, logger, alterFunc));
+}
+
+std::shared_ptr<DNSResponseAction> getDnstapLogResponseAction(const std::string& identity, std::shared_ptr<RemoteLoggerInterface> logger, std::optional<DnstapAlterResponseFunction> alterFunc)
+{
+ return std::shared_ptr<DNSResponseAction>(new DnstapLogResponseAction(identity, logger, alterFunc));
+}
+#endif /* DISABLE_PROTOBUF */
+
+#if defined(HAVE_LMDB) || defined(HAVE_CDB)
+std::shared_ptr<DNSAction> getKeyValueStoreLookupAction(std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey, const std::string& destinationTag)
+{
+ return std::shared_ptr<DNSAction>(new KeyValueStoreLookupAction(kvs, lookupKey, destinationTag));
+}
+
+std::shared_ptr<DNSAction> getKeyValueStoreRangeLookupAction(std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey, const std::string& destinationTag)
+{
+ return std::shared_ptr<DNSAction>(new KeyValueStoreRangeLookupAction(kvs, lookupKey, destinationTag));
+}
+#endif /* defined(HAVE_LMDB) || defined(HAVE_CDB) */
+
+#ifdef HAVE_DNS_OVER_HTTPS
+std::shared_ptr<DNSAction> getHTTPStatusAction(uint16_t status, PacketBuffer&& body, const std::string& contentType, const dnsdist::ResponseConfig& responseConfig)
+{
+ return std::shared_ptr<DNSAction>(new HTTPStatusAction(status, body, contentType, responseConfig));
+}
+
+#endif
+
+std::shared_ptr<DNSResponseAction> getLimitTTLResponseAction(uint32_t min, uint32_t max, std::unordered_set<QType> types)
+{
+ return std::shared_ptr<DNSResponseAction>(new LimitTTLResponseAction(min, max, types));
+}
+
+std::shared_ptr<DNSResponseAction> getMinTTLResponseAction(uint32_t min)
+{
+ return std::shared_ptr<DNSResponseAction>(new LimitTTLResponseAction(min));
+}
+
+std::shared_ptr<DNSResponseAction> getClearRecordTypesResponseAction(std::unordered_set<QType> types)
+{
+ return std::shared_ptr<DNSResponseAction>(new ClearRecordTypesResponseAction(types));
+}
+
+std::shared_ptr<DNSAction> getContinueAction(std::shared_ptr<DNSAction> action)
+{
+ return std::shared_ptr<DNSAction>(new ContinueAction(action));
+}
+
+std::shared_ptr<DNSAction> getNegativeAndSOAAction(bool nxd, const DNSName& zone, uint32_t ttl, const DNSName& mname, const DNSName& rname, const SOAParams& params, bool soaInAuthority, dnsdist::ResponseConfig responseConfig)
+{
+ return std::shared_ptr<DNSAction>(new NegativeAndSOAAction(nxd, zone, ttl, mname, rname, params, soaInAuthority, responseConfig));
+}
+
+std::shared_ptr<DNSAction> getRCodeAction(uint8_t rcode, const dnsdist::ResponseConfig& responseConfig)
+{
+ return std::shared_ptr<DNSAction>(new RCodeAction(rcode, responseConfig));
+}
+
+std::shared_ptr<DNSAction> getERCodeAction(uint8_t rcode, const dnsdist::ResponseConfig& responseConfig)
+{
+ return std::shared_ptr<DNSAction>(new ERCodeAction(rcode, responseConfig));
+}
+
+std::shared_ptr<DNSAction> getSetECSAction(const std::string& ipv4)
+{
+ return std::shared_ptr<DNSAction>(new SetECSAction(Netmask(ipv4)));
+}
+
+std::shared_ptr<DNSAction> getSetECSAction(const std::string& ipv4, const std::string& ipv6)
+{
+ return std::shared_ptr<DNSAction>(new SetECSAction(Netmask(ipv4), Netmask(ipv6)));
+}
+
+std::shared_ptr<DNSAction> getSpoofAction(const std::vector<ComboAddress>& addresses, const dnsdist::ResponseConfig& config)
+{
+ return std::shared_ptr<DNSAction>(new SpoofAction(addresses, config));
+}
+
+std::shared_ptr<DNSAction> getSpoofAction(const std::vector<std::string>& rawRDatas, std::optional<uint16_t> qtypeForAny, const dnsdist::ResponseConfig& config)
+{
+ return std::shared_ptr<DNSAction>(new SpoofAction(rawRDatas, qtypeForAny, config));
+}
+
+std::shared_ptr<DNSAction> getSpoofAction(const DNSName& cname, const dnsdist::ResponseConfig& config)
+{
+ return std::shared_ptr<DNSAction>(new SpoofAction(cname, config));
+}
+
+std::shared_ptr<DNSAction> getSpoofAction(const PacketBuffer& packet)
+{
+ return std::shared_ptr<DNSAction>(new SpoofAction(packet));
+}
+
+std::shared_ptr<DNSAction> getSpoofSVCAction(const std::vector<SVCRecordParameters>& parameters, const dnsdist::ResponseConfig& responseConfig)
+{
+ return std::shared_ptr<DNSAction>(new SpoofSVCAction(parameters, responseConfig));
+}
+
+std::shared_ptr<DNSAction> getSetMaxReturnedTTLAction(uint32_t max)
+{
+ return std::shared_ptr<DNSAction>(new SetMaxReturnedTTLAction(max));
+}
+
+std::shared_ptr<DNSResponseAction> getSetMaxReturnedTTLResponseAction(uint32_t max)
+{
+ return std::shared_ptr<DNSResponseAction>(new SetMaxReturnedTTLResponseAction(max));
+}
+
+std::shared_ptr<DNSResponseAction> getSetMaxTTLResponseAction(uint32_t max)
+{
+ return std::shared_ptr<DNSResponseAction>(new LimitTTLResponseAction(0, max));
+}
+
+std::shared_ptr<DNSAction> getSetProxyProtocolValuesAction(const std::vector<std::pair<uint8_t, std::string>>& values)
+{
+ return std::shared_ptr<DNSAction>(new SetProxyProtocolValuesAction(values));
+}
+
+std::shared_ptr<DNSAction> getTeeAction(const ComboAddress& rca, std::optional<ComboAddress> lca, bool addECS, bool addProxyProtocol)
+{
+ return std::shared_ptr<DNSAction>(new TeeAction(rca, lca, addECS, addProxyProtocol));
+}
+
+#include "dnsdist-actions-factory-generated.cc"
+#include "dnsdist-response-actions-factory-generated.cc"
+}
--- /dev/null
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+#include <cstdint>
+#include <map>
+#include <optional>
+#include <string>
+#include <unordered_set>
+
+struct DNSQuestion;
+struct DNSResponse;
+
+#include "dnsdist-actions.hh"
+#include "dnsdist-protobuf.hh"
+#include "dnsdist-svc.hh"
+#include "dnstap.hh"
+#include "iputils.hh"
+#include "noinitvector.hh"
+
+struct dnsdist_ffi_dnsquestion_t;
+struct dnsdist_ffi_dnsresponse_t;
+class RemoteLoggerInterface;
+class KeyValueStore;
+class KeyValueLookupKey;
+
+namespace dnsdist::actions
+{
+using LuaActionFunction = std::function<std::tuple<int, boost::optional<string>>(DNSQuestion* dnsquestion)>;
+using LuaResponseActionFunction = std::function<std::tuple<int, boost::optional<string>>(DNSResponse* response)>;
+using LuaActionFFIFunction = std::function<int(dnsdist_ffi_dnsquestion_t* dnsquestion)>;
+using LuaResponseActionFFIFunction = std::function<int(dnsdist_ffi_dnsresponse_t* dnsquestion)>;
+
+struct SOAParams
+{
+ uint32_t serial;
+ uint32_t refresh;
+ uint32_t retry;
+ uint32_t expire;
+ uint32_t minimum;
+};
+
+#include "dnsdist-actions-factory-generated.hh"
+#include "dnsdist-response-actions-factory-generated.hh"
+
+std::shared_ptr<DNSAction> getContinueAction(std::shared_ptr<DNSAction> action);
+#ifdef HAVE_DNS_OVER_HTTPS
+std::shared_ptr<DNSAction> getHTTPStatusAction(uint16_t status, PacketBuffer&& body, const std::string& contentType, const dnsdist::ResponseConfig& responseConfig);
+#endif
+std::shared_ptr<DNSAction> getNegativeAndSOAAction(bool nxd, const DNSName& zone, uint32_t ttl, const DNSName& mname, const DNSName& rname, const SOAParams& params, bool soaInAuthority, dnsdist::ResponseConfig responseConfig);
+std::shared_ptr<DNSAction> getSetProxyProtocolValuesAction(const std::vector<std::pair<uint8_t, std::string>>& values);
+std::shared_ptr<DNSAction> getRCodeAction(uint8_t rcode, const dnsdist::ResponseConfig& responseConfig);
+std::shared_ptr<DNSAction> getERCodeAction(uint8_t rcode, const dnsdist::ResponseConfig& responseConfig);
+
+#if defined(HAVE_LMDB) || defined(HAVE_CDB)
+std::shared_ptr<DNSAction> getKeyValueStoreLookupAction(std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey, const std::string& destinationTag);
+std::shared_ptr<DNSAction> getKeyValueStoreRangeLookupAction(std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey, const std::string& destinationTag);
+#endif /* defined(HAVE_LMDB) || defined(HAVE_CDB) */
+
+std::shared_ptr<DNSAction> getSetECSAction(const std::string& ipv4);
+std::shared_ptr<DNSAction> getSetECSAction(const std::string& ipv4, const std::string& ipv6);
+std::shared_ptr<DNSAction> getSpoofAction(const std::vector<ComboAddress>& addresses, const dnsdist::ResponseConfig& config);
+std::shared_ptr<DNSAction> getSpoofAction(const std::vector<std::string>& rawRDatas, std::optional<uint16_t> qtypeForAny, const dnsdist::ResponseConfig& config);
+std::shared_ptr<DNSAction> getSpoofAction(const DNSName& cname, const dnsdist::ResponseConfig& config);
+std::shared_ptr<DNSAction> getSpoofAction(const PacketBuffer& packet);
+
+std::shared_ptr<DNSAction> getSpoofSVCAction(const std::vector<SVCRecordParameters>& parameters, const dnsdist::ResponseConfig& responseConfig);
+
+std::shared_ptr<DNSAction> getSetMaxReturnedTTLAction(uint32_t max);
+std::shared_ptr<DNSResponseAction> getLimitTTLResponseAction(uint32_t min, uint32_t max = std::numeric_limits<uint32_t>::max(), std::unordered_set<QType> types = {});
+std::shared_ptr<DNSResponseAction> getMinTTLResponseAction(uint32_t min);
+std::shared_ptr<DNSResponseAction> getSetMaxReturnedTTLResponseAction(uint32_t max);
+std::shared_ptr<DNSResponseAction> getSetMaxTTLResponseAction(uint32_t max);
+
+std::shared_ptr<DNSResponseAction> getClearRecordTypesResponseAction(std::unordered_set<QType> types);
+
+std::shared_ptr<DNSAction> getTeeAction(const ComboAddress& rca, std::optional<ComboAddress> lca, bool addECS, bool addProxyProtocol);
+
+#ifndef DISABLE_PROTOBUF
+using ProtobufAlterFunction = std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)>;
+using ProtobufAlterResponseFunction = std::function<void(DNSResponse*, DNSDistProtoBufMessage*)>;
+using DnstapAlterFunction = std::function<void(DNSQuestion*, DnstapMessage*)>;
+using DnstapAlterResponseFunction = std::function<void(DNSResponse*, DnstapMessage*)>;
+
+struct RemoteLogActionConfiguration
+{
+ std::vector<std::pair<std::string, ProtoBufMetaKey>> metas;
+ std::optional<std::unordered_set<std::string>> tagsToExport{std::nullopt};
+ std::optional<ProtobufAlterFunction> alterQueryFunc;
+ std::optional<ProtobufAlterResponseFunction> alterResponseFunc;
+ std::shared_ptr<RemoteLoggerInterface> logger;
+ std::string serverID;
+ std::string ipEncryptKey;
+ std::optional<std::string> exportExtendedErrorsToMeta{std::nullopt};
+ bool includeCNAME{false};
+};
+std::shared_ptr<DNSAction> getRemoteLogAction(RemoteLogActionConfiguration& config);
+std::shared_ptr<DNSResponseAction> getRemoteLogResponseAction(RemoteLogActionConfiguration& config);
+std::shared_ptr<DNSAction> getDnstapLogAction(const std::string& identity, std::shared_ptr<RemoteLoggerInterface> logger, std::optional<DnstapAlterFunction> alterFunc);
+std::shared_ptr<DNSResponseAction> getDnstapLogResponseAction(const std::string& identity, std::shared_ptr<RemoteLoggerInterface> logger, std::optional<DnstapAlterResponseFunction> alterFunc);
+#endif /* DISABLE_PROTOBUF */
+}
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
*/
#include "config.h"
-#include "threadname.hh"
#include "dnsdist.hh"
-#include "dnsdist-async.hh"
+#include "dnsdist-actions-factory.hh"
#include "dnsdist-dnsparser.hh"
-#include "dnsdist-ecs.hh"
-#include "dnsdist-edns.hh"
#include "dnsdist-lua.hh"
#include "dnsdist-lua-ffi.hh"
-#include "dnsdist-mac-address.hh"
#include "dnsdist-protobuf.hh"
-#include "dnsdist-proxy-protocol.hh"
-#include "dnsdist-kvs.hh"
#include "dnsdist-rule-chains.hh"
-#include "dnsdist-self-answers.hh"
-#include "dnsdist-snmp.hh"
-#include "dnsdist-svc.hh"
-
-#include "dnstap.hh"
-#include "dnswriter.hh"
-#include "ednsoptions.hh"
-#include "fstrm_logger.hh"
-#include "remote_logger.hh"
-#include "svc-records.hh"
-
-#include <boost/optional/optional_io.hpp>
-
-#include "ipcipher.hh"
-
-class DropAction : public DNSAction
-{
-public:
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- return Action::Drop;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "drop";
- }
-};
-
-class AllowAction : public DNSAction
-{
-public:
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- return Action::Allow;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "allow";
- }
-};
-
-class NoneAction : public DNSAction
-{
-public:
- // this action does not stop the processing
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- return Action::None;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "no op";
- }
-};
-
-class QPSAction : public DNSAction
-{
-public:
- QPSAction(int limit) :
- d_qps(QPSLimiter(limit, limit))
- {
- }
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- if (d_qps.lock()->check()) {
- return Action::None;
- }
- return Action::Drop;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "qps limit to " + std::to_string(d_qps.lock()->getRate());
- }
-
-private:
- mutable LockGuarded<QPSLimiter> d_qps;
-};
-
-class DelayAction : public DNSAction
-{
-public:
- DelayAction(int msec) :
- d_msec(msec)
- {
- }
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- *ruleresult = std::to_string(d_msec);
- return Action::Delay;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "delay by " + std::to_string(d_msec) + " ms";
- }
-
-private:
- int d_msec;
-};
-
-class LimitTTLResponseAction : public DNSResponseAction, public boost::noncopyable
-{
-public:
- LimitTTLResponseAction() {}
-
- LimitTTLResponseAction(uint32_t min, uint32_t max = std::numeric_limits<uint32_t>::max(), const std::unordered_set<QType>& types = {}) :
- d_types(types), d_min(min), d_max(max)
- {
- }
-
- DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override
- {
- dnsdist::PacketMangling::restrictDNSPacketTTLs(dr->getMutableData(), d_min, d_max, d_types);
- return DNSResponseAction::Action::None;
- }
-
- std::string toString() const override
- {
- std::string result = "limit ttl (" + std::to_string(d_min) + " <= ttl <= " + std::to_string(d_max);
- if (!d_types.empty()) {
- bool first = true;
- result += ", types in [";
- for (const auto& type : d_types) {
- if (first) {
- first = false;
- }
- else {
- result += " ";
- }
- result += type.toString();
- }
- result += "]";
- }
- result += +")";
- return result;
- }
-
-private:
- std::unordered_set<QType> d_types;
- uint32_t d_min{0};
- uint32_t d_max{std::numeric_limits<uint32_t>::max()};
-};
-
-class TeeAction : public DNSAction
-{
-public:
- // this action does not stop the processing
- TeeAction(const ComboAddress& rca, const boost::optional<ComboAddress>& lca, bool addECS = false, bool addProxyProtocol = false);
- TeeAction(TeeAction& other) = delete;
- TeeAction(TeeAction&& other) = delete;
- TeeAction& operator=(TeeAction& other) = delete;
- TeeAction& operator=(TeeAction&& other) = delete;
- ~TeeAction() override;
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override;
- [[nodiscard]] std::string toString() const override;
- std::map<std::string, double> getStats() const override;
-
-private:
- void worker();
-
- ComboAddress d_remote;
- std::thread d_worker;
- Socket d_socket;
- mutable std::atomic<unsigned long> d_senderrors{0};
- unsigned long d_recverrors{0};
- mutable std::atomic<unsigned long> d_queries{0};
- stat_t d_responses{0};
- stat_t d_nxdomains{0};
- stat_t d_servfails{0};
- stat_t d_refuseds{0};
- stat_t d_formerrs{0};
- stat_t d_notimps{0};
- stat_t d_noerrors{0};
- mutable stat_t d_tcpdrops{0};
- stat_t d_otherrcode{0};
- std::atomic<bool> d_pleaseQuit{false};
- bool d_addECS{false};
- bool d_addProxyProtocol{false};
-};
-
-TeeAction::TeeAction(const ComboAddress& rca, const boost::optional<ComboAddress>& lca, bool addECS, bool addProxyProtocol) :
- d_remote(rca), d_socket(d_remote.sin4.sin_family, SOCK_DGRAM, 0), d_addECS(addECS), d_addProxyProtocol(addProxyProtocol)
-{
- if (lca) {
- d_socket.bind(*lca, false);
- }
- d_socket.connect(d_remote);
- d_socket.setNonBlocking();
- d_worker = std::thread([this]() {
- worker();
- });
-}
-
-TeeAction::~TeeAction()
-{
- d_pleaseQuit = true;
- close(d_socket.releaseHandle());
- d_worker.join();
-}
-
-DNSAction::Action TeeAction::operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const
-{
- if (dnsquestion->overTCP()) {
- d_tcpdrops++;
- return DNSAction::Action::None;
- }
-
- d_queries++;
-
- PacketBuffer query;
- if (d_addECS) {
- query = dnsquestion->getData();
- bool ednsAdded = false;
- bool ecsAdded = false;
-
- std::string newECSOption;
- generateECSOption(dnsquestion->ecs ? dnsquestion->ecs->getNetwork() : dnsquestion->ids.origRemote, newECSOption, dnsquestion->ecs ? dnsquestion->ecs->getBits() : dnsquestion->ecsPrefixLength);
-
- if (!handleEDNSClientSubnet(query, dnsquestion->getMaximumSize(), dnsquestion->ids.qname.wirelength(), ednsAdded, ecsAdded, dnsquestion->ecsOverride, newECSOption)) {
- return DNSAction::Action::None;
- }
- }
-
- if (d_addProxyProtocol) {
- auto proxyPayload = getProxyProtocolPayload(*dnsquestion);
- if (query.empty()) {
- query = dnsquestion->getData();
- }
- if (!addProxyProtocol(query, proxyPayload)) {
- return DNSAction::Action::None;
- }
- }
-
- {
- const PacketBuffer& payload = query.empty() ? dnsquestion->getData() : query;
- auto res = send(d_socket.getHandle(), payload.data(), payload.size(), 0);
-
- if (res <= 0) {
- d_senderrors++;
- }
- }
-
- return DNSAction::Action::None;
-}
-
-std::string TeeAction::toString() const
-{
- return "tee to " + d_remote.toStringWithPort();
-}
-
-std::map<std::string, double> TeeAction::getStats() const
-{
- return {{"queries", d_queries},
- {"responses", d_responses},
- {"recv-errors", d_recverrors},
- {"send-errors", d_senderrors},
- {"noerrors", d_noerrors},
- {"nxdomains", d_nxdomains},
- {"refuseds", d_refuseds},
- {"servfails", d_servfails},
- {"other-rcode", d_otherrcode},
- {"tcp-drops", d_tcpdrops}};
-}
-
-void TeeAction::worker()
-{
- setThreadName("dnsdist/TeeWork");
- std::array<char, dnsdist::configuration::s_udpIncomingBufferSize> packet{};
- ssize_t res = 0;
- const dnsheader_aligned dnsheader(packet.data());
- for (;;) {
- res = waitForData(d_socket.getHandle(), 0, 250000);
- if (d_pleaseQuit) {
- break;
- }
-
- if (res < 0) {
- usleep(250000);
- continue;
- }
- if (res == 0) {
- continue;
- }
- res = recv(d_socket.getHandle(), packet.data(), packet.size(), 0);
- if (static_cast<size_t>(res) <= sizeof(struct dnsheader)) {
- d_recverrors++;
- }
- else {
- d_responses++;
- }
-
- // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well
- if (dnsheader->rcode == RCode::NoError) {
- d_noerrors++;
- }
- // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well
- else if (dnsheader->rcode == RCode::ServFail) {
- d_servfails++;
- }
- // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well
- else if (dnsheader->rcode == RCode::NXDomain) {
- d_nxdomains++;
- }
- // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well
- else if (dnsheader->rcode == RCode::Refused) {
- d_refuseds++;
- }
- // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well
- else if (dnsheader->rcode == RCode::FormErr) {
- d_formerrs++;
- }
- // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well
- else if (dnsheader->rcode == RCode::NotImp) {
- d_notimps++;
- }
- }
-}
-
-class PoolAction : public DNSAction
-{
-public:
- PoolAction(std::string pool, bool stopProcessing) :
- d_pool(std::move(pool)), d_stopProcessing(stopProcessing) {}
-
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- if (d_stopProcessing) {
- /* we need to do it that way to keep compatiblity with custom Lua actions returning DNSAction.Pool, 'poolname' */
- *ruleresult = d_pool;
- return Action::Pool;
- }
- dnsquestion->ids.poolName = d_pool;
- return Action::None;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return "to pool " + d_pool;
- }
-
-private:
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
- const std::string d_pool;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
- const bool d_stopProcessing;
-};
-
-class QPSPoolAction : public DNSAction
-{
-public:
- QPSPoolAction(unsigned int limit, std::string pool, bool stopProcessing) :
- d_qps(QPSLimiter(limit, limit)), d_pool(std::move(pool)), d_stopProcessing(stopProcessing) {}
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- if (d_qps.lock()->check()) {
- if (d_stopProcessing) {
- /* we need to do it that way to keep compatiblity with custom Lua actions returning DNSAction.Pool, 'poolname' */
- *ruleresult = d_pool;
- return Action::Pool;
- }
- dnsquestion->ids.poolName = d_pool;
- }
- return Action::None;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "max " + std::to_string(d_qps.lock()->getRate()) + " to pool " + d_pool;
- }
-
-private:
- mutable LockGuarded<QPSLimiter> d_qps;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
- const std::string d_pool;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
- const bool d_stopProcessing;
-};
-
-class RCodeAction : public DNSAction
-{
-public:
- RCodeAction(uint8_t rcode, dnsdist::ResponseConfig responseConfig) :
- d_responseConfig(responseConfig), d_rcode(rcode) {}
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) {
- header.rcode = d_rcode;
- header.qr = true; // for good measure
- setResponseHeadersFromConfig(header, d_responseConfig);
- return true;
- });
- return Action::HeaderModify;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "set rcode " + std::to_string(d_rcode);
- }
-
-private:
- dnsdist::ResponseConfig d_responseConfig;
- uint8_t d_rcode;
-};
-
-class ERCodeAction : public DNSAction
-{
-public:
- ERCodeAction(uint8_t rcode, dnsdist::ResponseConfig responseConfig) :
- d_responseConfig(responseConfig), d_rcode(rcode) {}
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) {
- header.rcode = (d_rcode & 0xF);
- header.qr = true; // for good measure
- setResponseHeadersFromConfig(header, d_responseConfig);
- return true;
- });
- dnsquestion->ednsRCode = ((d_rcode & 0xFFF0) >> 4);
- return Action::HeaderModify;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "set ercode " + ERCode::to_s(d_rcode);
- }
-
-private:
- dnsdist::ResponseConfig d_responseConfig;
- uint8_t d_rcode;
-};
-
-class SpoofSVCAction : public DNSAction
-{
-public:
- SpoofSVCAction(const LuaArray<SVCRecordParameters>& parameters, dnsdist::ResponseConfig responseConfig) :
- d_responseConfig(responseConfig)
- {
- d_payloads.reserve(parameters.size());
-
- for (const auto& param : parameters) {
- std::vector<uint8_t> payload;
- if (!generateSVCPayload(payload, param.second)) {
- throw std::runtime_error("Unable to generate a valid SVC record from the supplied parameters");
- }
-
- d_payloads.push_back(std::move(payload));
-
- for (const auto& hint : param.second.ipv4hints) {
- d_additionals4.insert({param.second.target, ComboAddress(hint)});
- }
-
- for (const auto& hint : param.second.ipv6hints) {
- d_additionals6.insert({param.second.target, ComboAddress(hint)});
- }
- }
- }
-
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- if (!dnsdist::svc::generateSVCResponse(*dnsquestion, d_payloads, d_additionals4, d_additionals6, d_responseConfig)) {
- return Action::None;
- }
-
- return Action::HeaderModify;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return "spoof SVC record ";
- }
-
-private:
- dnsdist::ResponseConfig d_responseConfig;
- std::vector<std::vector<uint8_t>> d_payloads{};
- std::set<std::pair<DNSName, ComboAddress>> d_additionals4{};
- std::set<std::pair<DNSName, ComboAddress>> d_additionals6{};
-};
-
-class TCAction : public DNSAction
-{
-public:
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- return Action::Truncate;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "tc=1 answer";
- }
-};
-
-class TCResponseAction : public DNSResponseAction
-{
-public:
- DNSResponseAction::Action operator()(DNSResponse* dnsResponse, std::string* ruleresult) const override
- {
- return Action::Truncate;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "tc=1 answer";
- }
-};
-
-class LuaAction : public DNSAction
-{
-public:
- using func_t = std::function<std::tuple<int, boost::optional<string>>(DNSQuestion* dnsquestion)>;
- LuaAction(LuaAction::func_t func) :
- d_func(std::move(func))
- {}
-
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- try {
- DNSAction::Action result{};
- {
- auto lock = g_lua.lock();
- auto ret = d_func(dnsquestion);
- if (ruleresult != nullptr) {
- if (boost::optional<std::string> rule = std::get<1>(ret)) {
- *ruleresult = *rule;
- }
- else {
- // default to empty string
- ruleresult->clear();
- }
- }
- result = static_cast<Action>(std::get<0>(ret));
- }
- dnsdist::handleQueuedAsynchronousEvents();
- return result;
- }
- catch (const std::exception& e) {
- warnlog("LuaAction failed inside Lua, returning ServFail: %s", e.what());
- }
- catch (...) {
- warnlog("LuaAction failed inside Lua, returning ServFail: [unknown exception]");
- }
- return DNSAction::Action::ServFail;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return "Lua script";
- }
-
-private:
- func_t d_func;
-};
-
-class LuaResponseAction : public DNSResponseAction
-{
-public:
- using func_t = std::function<std::tuple<int, boost::optional<string>>(DNSResponse* response)>;
- LuaResponseAction(LuaResponseAction::func_t func) :
- d_func(std::move(func))
- {}
- DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
- {
- try {
- DNSResponseAction::Action result{};
- {
- auto lock = g_lua.lock();
- auto ret = d_func(response);
- if (ruleresult != nullptr) {
- if (boost::optional<std::string> rule = std::get<1>(ret)) {
- *ruleresult = *rule;
- }
- else {
- // default to empty string
- ruleresult->clear();
- }
- }
- result = static_cast<Action>(std::get<0>(ret));
- }
- dnsdist::handleQueuedAsynchronousEvents();
- return result;
- }
- catch (const std::exception& e) {
- warnlog("LuaResponseAction failed inside Lua, returning ServFail: %s", e.what());
- }
- catch (...) {
- warnlog("LuaResponseAction failed inside Lua, returning ServFail: [unknown exception]");
- }
- return DNSResponseAction::Action::ServFail;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return "Lua response script";
- }
-
-private:
- func_t d_func;
-};
-
-class LuaFFIAction : public DNSAction
-{
-public:
- using func_t = std::function<int(dnsdist_ffi_dnsquestion_t* dnsquestion)>;
-
- LuaFFIAction(LuaFFIAction::func_t func) :
- d_func(std::move(func))
- {
- }
-
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- dnsdist_ffi_dnsquestion_t dqffi(dnsquestion);
- try {
- DNSAction::Action result{};
- {
- auto lock = g_lua.lock();
- auto ret = d_func(&dqffi);
- if (ruleresult != nullptr) {
- if (dqffi.result) {
- *ruleresult = *dqffi.result;
- }
- else {
- // default to empty string
- ruleresult->clear();
- }
- }
- result = static_cast<DNSAction::Action>(ret);
- }
- dnsdist::handleQueuedAsynchronousEvents();
- return result;
- }
- catch (const std::exception& e) {
- warnlog("LuaFFIAction failed inside Lua, returning ServFail: %s", e.what());
- }
- catch (...) {
- warnlog("LuaFFIAction failed inside Lua, returning ServFail: [unknown exception]");
- }
- return DNSAction::Action::ServFail;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return "Lua FFI script";
- }
-
-private:
- func_t d_func;
-};
-
-class LuaFFIPerThreadAction : public DNSAction
-{
-public:
- using func_t = std::function<int(dnsdist_ffi_dnsquestion_t* dnsquestion)>;
-
- LuaFFIPerThreadAction(std::string code) :
- d_functionCode(std::move(code)), d_functionID(s_functionsCounter++)
- {
- }
-
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- try {
- auto& state = t_perThreadStates[d_functionID];
- if (!state.d_initialized) {
- setupLuaFFIPerThreadContext(state.d_luaContext);
- /* mark the state as initialized first so if there is a syntax error
- we only try to execute the code once */
- state.d_initialized = true;
- state.d_func = state.d_luaContext.executeCode<func_t>(d_functionCode);
- }
-
- if (!state.d_func) {
- /* the function was not properly initialized */
- return DNSAction::Action::None;
- }
-
- dnsdist_ffi_dnsquestion_t dqffi(dnsquestion);
- auto ret = state.d_func(&dqffi);
- if (ruleresult != nullptr) {
- if (dqffi.result) {
- *ruleresult = *dqffi.result;
- }
- else {
- // default to empty string
- ruleresult->clear();
- }
- }
- dnsdist::handleQueuedAsynchronousEvents();
- return static_cast<DNSAction::Action>(ret);
- }
- catch (const std::exception& e) {
- warnlog("LuaFFIPerThreadAction failed inside Lua, returning ServFail: %s", e.what());
- }
- catch (...) {
- warnlog("LuaFFIPerthreadAction failed inside Lua, returning ServFail: [unknown exception]");
- }
- return DNSAction::Action::ServFail;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return "Lua FFI per-thread script";
- }
-
-private:
- struct PerThreadState
- {
- LuaContext d_luaContext;
- func_t d_func;
- bool d_initialized{false};
- };
- static std::atomic<uint64_t> s_functionsCounter;
- static thread_local std::map<uint64_t, PerThreadState> t_perThreadStates;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
- const std::string d_functionCode;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
- const uint64_t d_functionID;
-};
-
-std::atomic<uint64_t> LuaFFIPerThreadAction::s_functionsCounter = 0;
-thread_local std::map<uint64_t, LuaFFIPerThreadAction::PerThreadState> LuaFFIPerThreadAction::t_perThreadStates;
-
-class LuaFFIResponseAction : public DNSResponseAction
-{
-public:
- using func_t = std::function<int(dnsdist_ffi_dnsresponse_t* dnsquestion)>;
-
- LuaFFIResponseAction(LuaFFIResponseAction::func_t func) :
- d_func(std::move(func))
- {
- }
-
- DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
- {
- dnsdist_ffi_dnsresponse_t ffiResponse(response);
- try {
- DNSResponseAction::Action result{};
- {
- auto lock = g_lua.lock();
- auto ret = d_func(&ffiResponse);
- if (ruleresult != nullptr) {
- if (ffiResponse.result) {
- *ruleresult = *ffiResponse.result;
- }
- else {
- // default to empty string
- ruleresult->clear();
- }
- }
- result = static_cast<DNSResponseAction::Action>(ret);
- }
- dnsdist::handleQueuedAsynchronousEvents();
- return result;
- }
- catch (const std::exception& e) {
- warnlog("LuaFFIResponseAction failed inside Lua, returning ServFail: %s", e.what());
- }
- catch (...) {
- warnlog("LuaFFIResponseAction failed inside Lua, returning ServFail: [unknown exception]");
- }
- return DNSResponseAction::Action::ServFail;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return "Lua FFI script";
- }
-
-private:
- func_t d_func;
-};
-
-class LuaFFIPerThreadResponseAction : public DNSResponseAction
-{
-public:
- using func_t = std::function<int(dnsdist_ffi_dnsresponse_t* response)>;
-
- LuaFFIPerThreadResponseAction(std::string code) :
- d_functionCode(std::move(code)), d_functionID(s_functionsCounter++)
- {
- }
-
- DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
- {
- try {
- auto& state = t_perThreadStates[d_functionID];
- if (!state.d_initialized) {
- setupLuaFFIPerThreadContext(state.d_luaContext);
- /* mark the state as initialized first so if there is a syntax error
- we only try to execute the code once */
- state.d_initialized = true;
- state.d_func = state.d_luaContext.executeCode<func_t>(d_functionCode);
- }
-
- if (!state.d_func) {
- /* the function was not properly initialized */
- return DNSResponseAction::Action::None;
- }
-
- dnsdist_ffi_dnsresponse_t ffiResponse(response);
- auto ret = state.d_func(&ffiResponse);
- if (ruleresult != nullptr) {
- if (ffiResponse.result) {
- *ruleresult = *ffiResponse.result;
- }
- else {
- // default to empty string
- ruleresult->clear();
- }
- }
- dnsdist::handleQueuedAsynchronousEvents();
- return static_cast<DNSResponseAction::Action>(ret);
- }
- catch (const std::exception& e) {
- warnlog("LuaFFIPerThreadResponseAction failed inside Lua, returning ServFail: %s", e.what());
- }
- catch (...) {
- warnlog("LuaFFIPerthreadResponseAction failed inside Lua, returning ServFail: [unknown exception]");
- }
- return DNSResponseAction::Action::ServFail;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return "Lua FFI per-thread script";
- }
-
-private:
- struct PerThreadState
- {
- LuaContext d_luaContext;
- func_t d_func;
- bool d_initialized{false};
- };
-
- static std::atomic<uint64_t> s_functionsCounter;
- static thread_local std::map<uint64_t, PerThreadState> t_perThreadStates;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
- const std::string d_functionCode;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
- const uint64_t d_functionID;
-};
-
-std::atomic<uint64_t> LuaFFIPerThreadResponseAction::s_functionsCounter = 0;
-thread_local std::map<uint64_t, LuaFFIPerThreadResponseAction::PerThreadState> LuaFFIPerThreadResponseAction::t_perThreadStates;
-
-class SpoofAction : public DNSAction
-{
-public:
- SpoofAction(const vector<ComboAddress>& addrs, const dnsdist::ResponseConfig& responseConfig) :
- d_responseConfig(responseConfig), d_addrs(addrs)
- {
- for (const auto& addr : d_addrs) {
- if (addr.isIPv4()) {
- d_types.insert(QType::A);
- }
- else if (addr.isIPv6()) {
- d_types.insert(QType::AAAA);
- }
- }
-
- if (!d_addrs.empty()) {
- d_types.insert(QType::ANY);
- }
- }
-
- SpoofAction(const DNSName& cname, const dnsdist::ResponseConfig& responseConfig) :
- d_responseConfig(responseConfig), d_cname(cname)
- {
- }
-
- SpoofAction(const PacketBuffer& rawresponse) :
- d_raw(rawresponse)
- {
- }
-
- SpoofAction(const vector<std::string>& raws, std::optional<uint16_t> typeForAny, const dnsdist::ResponseConfig& responseConfig) :
- d_responseConfig(responseConfig), d_rawResponses(raws), d_rawTypeForAny(typeForAny)
- {
- }
-
- DNSAction::Action operator()(DNSQuestion* dnsquestion, string* ruleresult) const override;
-
- string toString() const override
- {
- string ret = "spoof in ";
- if (!d_cname.empty()) {
- ret += d_cname.toString() + " ";
- }
- if (d_rawResponses.size() > 0) {
- ret += "raw bytes ";
- }
- else {
- for (const auto& a : d_addrs)
- ret += a.toString() + " ";
- }
- return ret;
- }
-
-private:
- dnsdist::ResponseConfig d_responseConfig;
- std::vector<ComboAddress> d_addrs;
- std::unordered_set<uint16_t> d_types;
- std::vector<std::string> d_rawResponses;
- PacketBuffer d_raw;
- DNSName d_cname;
- std::optional<uint16_t> d_rawTypeForAny{};
-};
-
-DNSAction::Action SpoofAction::operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const
-{
- uint16_t qtype = dnsquestion->ids.qtype;
- // do we even have a response?
- if (d_cname.empty() && d_rawResponses.empty() &&
- // make sure pre-forged response is greater than sizeof(dnsheader)
- (d_raw.size() < sizeof(dnsheader)) && d_types.count(qtype) == 0) {
- return Action::None;
- }
-
- if (d_raw.size() >= sizeof(dnsheader)) {
- dnsdist::self_answers::generateAnswerFromRawPacket(*dnsquestion, d_raw);
- return Action::HeaderModify;
- }
-
- if (!d_cname.empty()) {
- if (dnsdist::self_answers::generateAnswerFromCNAME(*dnsquestion, d_cname, d_responseConfig)) {
- return Action::HeaderModify;
- }
- }
- else if (!d_rawResponses.empty()) {
- if (dnsdist::self_answers::generateAnswerFromRDataEntries(*dnsquestion, d_rawResponses, d_rawTypeForAny, d_responseConfig)) {
- return Action::HeaderModify;
- }
- }
- else {
- if (dnsdist::self_answers::generateAnswerFromIPAddresses(*dnsquestion, d_addrs, d_responseConfig)) {
- return Action::HeaderModify;
- }
- }
-
- return Action::None;
-}
-
-class SetMacAddrAction : public DNSAction
-{
-public:
- // this action does not stop the processing
- SetMacAddrAction(uint16_t code) :
- d_code(code)
- {
- }
-
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- dnsdist::MacAddress mac{};
- int res = dnsdist::MacAddressesCache::get(dnsquestion->ids.origRemote, mac.data(), mac.size());
- if (res != 0) {
- return Action::None;
- }
-
- std::string optRData;
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
- generateEDNSOption(d_code, reinterpret_cast<const char*>(mac.data()), optRData);
-
- if (dnsquestion->getHeader()->arcount > 0) {
- bool ednsAdded = false;
- bool optionAdded = false;
- PacketBuffer newContent;
- newContent.reserve(dnsquestion->getData().size());
-
- if (!slowRewriteEDNSOptionInQueryWithRecords(dnsquestion->getData(), newContent, ednsAdded, d_code, optionAdded, true, optRData)) {
- return Action::None;
- }
-
- if (newContent.size() > dnsquestion->getMaximumSize()) {
- return Action::None;
- }
-
- dnsquestion->getMutableData() = std::move(newContent);
- if (!dnsquestion->ids.ednsAdded && ednsAdded) {
- dnsquestion->ids.ednsAdded = true;
- }
-
- return Action::None;
- }
-
- auto& data = dnsquestion->getMutableData();
- if (generateOptRR(optRData, data, dnsquestion->getMaximumSize(), dnsdist::configuration::s_EdnsUDPPayloadSize, 0, false)) {
- dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [](dnsheader& header) {
- header.arcount = htons(1);
- return true;
- });
- // make sure that any EDNS sent by the backend is removed before forwarding the response to the client
- dnsquestion->ids.ednsAdded = true;
- }
-
- return Action::None;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "add EDNS MAC (code=" + std::to_string(d_code) + ")";
- }
-
-private:
- uint16_t d_code{3};
-};
-
-class SetEDNSOptionAction : public DNSAction
-{
-public:
- // this action does not stop the processing
- SetEDNSOptionAction(uint16_t code, std::string data) :
- d_code(code), d_data(std::move(data))
- {
- }
-
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- setEDNSOption(*dnsquestion, d_code, d_data);
- return Action::None;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return "add EDNS Option (code=" + std::to_string(d_code) + ")";
- }
-
-private:
- uint16_t d_code;
- std::string d_data;
-};
-
-class SetNoRecurseAction : public DNSAction
-{
-public:
- // this action does not stop the processing
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [](dnsheader& header) {
- header.rd = false;
- return true;
- });
- return Action::None;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "set rd=0";
- }
-};
-
-class LogAction : public DNSAction, public boost::noncopyable
-{
-public:
- // this action does not stop the processing
- LogAction() = default;
-
- LogAction(const std::string& str, bool binary = true, bool append = false, bool buffered = true, bool verboseOnly = true, bool includeTimestamp = false) :
- d_fname(str), d_binary(binary), d_verboseOnly(verboseOnly), d_includeTimestamp(includeTimestamp), d_append(append), d_buffered(buffered)
- {
- if (str.empty()) {
- return;
- }
-
- if (!reopenLogFile()) {
- throw std::runtime_error("Unable to open file '" + str + "' for logging: " + stringerror());
- }
- }
-
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- auto filepointer = std::atomic_load_explicit(&d_fp, std::memory_order_acquire);
- if (!filepointer) {
- if (!d_verboseOnly || dnsdist::configuration::getCurrentRuntimeConfiguration().d_verbose) {
- if (d_includeTimestamp) {
- infolog("[%u.%u] Packet from %s for %s %s with id %d", static_cast<unsigned long long>(dnsquestion->getQueryRealTime().tv_sec), static_cast<unsigned long>(dnsquestion->getQueryRealTime().tv_nsec), dnsquestion->ids.origRemote.toStringWithPort(), dnsquestion->ids.qname.toString(), QType(dnsquestion->ids.qtype).toString(), dnsquestion->getHeader()->id);
- }
- else {
- infolog("Packet from %s for %s %s with id %d", dnsquestion->ids.origRemote.toStringWithPort(), dnsquestion->ids.qname.toString(), QType(dnsquestion->ids.qtype).toString(), dnsquestion->getHeader()->id);
- }
- }
- }
- else {
- if (d_binary) {
- const auto& out = dnsquestion->ids.qname.getStorage();
- if (d_includeTimestamp) {
- auto tv_sec = static_cast<uint64_t>(dnsquestion->getQueryRealTime().tv_sec);
- auto tv_nsec = static_cast<uint32_t>(dnsquestion->getQueryRealTime().tv_nsec);
- fwrite(&tv_sec, sizeof(tv_sec), 1, filepointer.get());
- fwrite(&tv_nsec, sizeof(tv_nsec), 1, filepointer.get());
- }
- uint16_t queryId = dnsquestion->getHeader()->id;
- fwrite(&queryId, sizeof(queryId), 1, filepointer.get());
- fwrite(out.c_str(), 1, out.size(), filepointer.get());
- fwrite(&dnsquestion->ids.qtype, sizeof(dnsquestion->ids.qtype), 1, filepointer.get());
- fwrite(&dnsquestion->ids.origRemote.sin4.sin_family, sizeof(dnsquestion->ids.origRemote.sin4.sin_family), 1, filepointer.get());
- if (dnsquestion->ids.origRemote.sin4.sin_family == AF_INET) {
- fwrite(&dnsquestion->ids.origRemote.sin4.sin_addr.s_addr, sizeof(dnsquestion->ids.origRemote.sin4.sin_addr.s_addr), 1, filepointer.get());
- }
- else if (dnsquestion->ids.origRemote.sin4.sin_family == AF_INET6) {
- fwrite(&dnsquestion->ids.origRemote.sin6.sin6_addr.s6_addr, sizeof(dnsquestion->ids.origRemote.sin6.sin6_addr.s6_addr), 1, filepointer.get());
- }
- fwrite(&dnsquestion->ids.origRemote.sin4.sin_port, sizeof(dnsquestion->ids.origRemote.sin4.sin_port), 1, filepointer.get());
- }
- else {
- if (d_includeTimestamp) {
- fprintf(filepointer.get(), "[%llu.%lu] Packet from %s for %s %s with id %u\n", static_cast<unsigned long long>(dnsquestion->getQueryRealTime().tv_sec), static_cast<unsigned long>(dnsquestion->getQueryRealTime().tv_nsec), dnsquestion->ids.origRemote.toStringWithPort().c_str(), dnsquestion->ids.qname.toString().c_str(), QType(dnsquestion->ids.qtype).toString().c_str(), dnsquestion->getHeader()->id);
- }
- else {
- fprintf(filepointer.get(), "Packet from %s for %s %s with id %u\n", dnsquestion->ids.origRemote.toStringWithPort().c_str(), dnsquestion->ids.qname.toString().c_str(), QType(dnsquestion->ids.qtype).toString().c_str(), dnsquestion->getHeader()->id);
- }
- }
- }
- return Action::None;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- if (!d_fname.empty()) {
- return "log to " + d_fname;
- }
- return "log";
- }
-
- void reload() override
- {
- if (!reopenLogFile()) {
- warnlog("Unable to open file '%s' for logging: %s", d_fname, stringerror());
- }
- }
-
-private:
- bool reopenLogFile()
- {
- // we are using a naked pointer here because we don't want fclose to be called
- // with a nullptr, which would happen if we constructor a shared_ptr with fclose
- // as a custom deleter and nullptr as a FILE*
- // NOLINTNEXTLINE(cppcoreguidelines-owning-memory)
- auto* nfp = fopen(d_fname.c_str(), d_append ? "a+" : "w");
- if (nfp == nullptr) {
- /* don't fall on our sword when reopening */
- return false;
- }
-
- auto filepointer = std::shared_ptr<FILE>(nfp, fclose);
- nfp = nullptr;
-
- if (!d_buffered) {
- setbuf(filepointer.get(), nullptr);
- }
-
- std::atomic_store_explicit(&d_fp, std::move(filepointer), std::memory_order_release);
- return true;
- }
-
- std::string d_fname;
- std::shared_ptr<FILE> d_fp{nullptr};
- bool d_binary{true};
- bool d_verboseOnly{true};
- bool d_includeTimestamp{false};
- bool d_append{false};
- bool d_buffered{true};
-};
-
-class LogResponseAction : public DNSResponseAction, public boost::noncopyable
-{
-public:
- LogResponseAction() = default;
-
- LogResponseAction(const std::string& str, bool append = false, bool buffered = true, bool verboseOnly = true, bool includeTimestamp = false) :
- d_fname(str), d_verboseOnly(verboseOnly), d_includeTimestamp(includeTimestamp), d_append(append), d_buffered(buffered)
- {
- if (str.empty()) {
- return;
- }
-
- if (!reopenLogFile()) {
- throw std::runtime_error("Unable to open file '" + str + "' for logging: " + stringerror());
- }
- }
-
- DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
- {
- auto filepointer = std::atomic_load_explicit(&d_fp, std::memory_order_acquire);
- if (!filepointer) {
- if (!d_verboseOnly || dnsdist::configuration::getCurrentRuntimeConfiguration().d_verbose) {
- if (d_includeTimestamp) {
- infolog("[%u.%u] Answer to %s for %s %s (%s) with id %u", static_cast<unsigned long long>(response->getQueryRealTime().tv_sec), static_cast<unsigned long>(response->getQueryRealTime().tv_nsec), response->ids.origRemote.toStringWithPort(), response->ids.qname.toString(), QType(response->ids.qtype).toString(), RCode::to_s(response->getHeader()->rcode), response->getHeader()->id);
- }
- else {
- infolog("Answer to %s for %s %s (%s) with id %u", response->ids.origRemote.toStringWithPort(), response->ids.qname.toString(), QType(response->ids.qtype).toString(), RCode::to_s(response->getHeader()->rcode), response->getHeader()->id);
- }
- }
- }
- else {
- if (d_includeTimestamp) {
- fprintf(filepointer.get(), "[%llu.%lu] Answer to %s for %s %s (%s) with id %u\n", static_cast<unsigned long long>(response->getQueryRealTime().tv_sec), static_cast<unsigned long>(response->getQueryRealTime().tv_nsec), response->ids.origRemote.toStringWithPort().c_str(), response->ids.qname.toString().c_str(), QType(response->ids.qtype).toString().c_str(), RCode::to_s(response->getHeader()->rcode).c_str(), response->getHeader()->id);
- }
- else {
- fprintf(filepointer.get(), "Answer to %s for %s %s (%s) with id %u\n", response->ids.origRemote.toStringWithPort().c_str(), response->ids.qname.toString().c_str(), QType(response->ids.qtype).toString().c_str(), RCode::to_s(response->getHeader()->rcode).c_str(), response->getHeader()->id);
- }
- }
- return Action::None;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- if (!d_fname.empty()) {
- return "log to " + d_fname;
- }
- return "log";
- }
-
- void reload() override
- {
- if (!reopenLogFile()) {
- warnlog("Unable to open file '%s' for logging: %s", d_fname, stringerror());
- }
- }
-
-private:
- bool reopenLogFile()
- {
- // we are using a naked pointer here because we don't want fclose to be called
- // with a nullptr, which would happen if we constructor a shared_ptr with fclose
- // as a custom deleter and nullptr as a FILE*
- // NOLINTNEXTLINE(cppcoreguidelines-owning-memory)
- auto* nfp = fopen(d_fname.c_str(), d_append ? "a+" : "w");
- if (nfp == nullptr) {
- /* don't fall on our sword when reopening */
- return false;
- }
-
- auto filepointer = std::shared_ptr<FILE>(nfp, fclose);
- nfp = nullptr;
-
- if (!d_buffered) {
- setbuf(filepointer.get(), nullptr);
- }
-
- std::atomic_store_explicit(&d_fp, std::move(filepointer), std::memory_order_release);
- return true;
- }
-
- std::string d_fname;
- std::shared_ptr<FILE> d_fp{nullptr};
- bool d_verboseOnly{true};
- bool d_includeTimestamp{false};
- bool d_append{false};
- bool d_buffered{true};
-};
-
-class SetDisableValidationAction : public DNSAction
-{
-public:
- // this action does not stop the processing
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [](dnsheader& header) {
- header.cd = true;
- return true;
- });
- return Action::None;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "set cd=1";
- }
-};
-
-class SetSkipCacheAction : public DNSAction
-{
-public:
- // this action does not stop the processing
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- dnsquestion->ids.skipCache = true;
- return Action::None;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "skip cache";
- }
-};
-
-class SetSkipCacheResponseAction : public DNSResponseAction
-{
-public:
- DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
- {
- response->ids.skipCache = true;
- return Action::None;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "skip cache";
- }
-};
-
-class SetTempFailureCacheTTLAction : public DNSAction
-{
-public:
- // this action does not stop the processing
- SetTempFailureCacheTTLAction(uint32_t ttl) :
- d_ttl(ttl)
- {
- }
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- dnsquestion->ids.tempFailureTTL = d_ttl;
- return Action::None;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "set tempfailure cache ttl to " + std::to_string(d_ttl);
- }
-
-private:
- uint32_t d_ttl;
-};
-
-class SetECSPrefixLengthAction : public DNSAction
-{
-public:
- // this action does not stop the processing
- SetECSPrefixLengthAction(uint16_t v4Length, uint16_t v6Length) :
- d_v4PrefixLength(v4Length), d_v6PrefixLength(v6Length)
- {
- }
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- dnsquestion->ecsPrefixLength = dnsquestion->ids.origRemote.sin4.sin_family == AF_INET ? d_v4PrefixLength : d_v6PrefixLength;
- return Action::None;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "set ECS prefix length to " + std::to_string(d_v4PrefixLength) + "/" + std::to_string(d_v6PrefixLength);
- }
-
-private:
- uint16_t d_v4PrefixLength;
- uint16_t d_v6PrefixLength;
-};
-
-class SetECSOverrideAction : public DNSAction
-{
-public:
- // this action does not stop the processing
- SetECSOverrideAction(bool ecsOverride) :
- d_ecsOverride(ecsOverride)
- {
- }
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- dnsquestion->ecsOverride = d_ecsOverride;
- return Action::None;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "set ECS override to " + std::to_string(static_cast<int>(d_ecsOverride));
- }
-
-private:
- bool d_ecsOverride;
-};
-
-class SetDisableECSAction : public DNSAction
-{
-public:
- // this action does not stop the processing
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- dnsquestion->useECS = false;
- return Action::None;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "disable ECS";
- }
-};
-
-class SetECSAction : public DNSAction
-{
-public:
- // this action does not stop the processing
- SetECSAction(const Netmask& v4Netmask) :
- d_v4(v4Netmask), d_hasV6(false)
- {
- }
-
- SetECSAction(const Netmask& v4Netmask, const Netmask& v6Netmask) :
- d_v4(v4Netmask), d_v6(v6Netmask), d_hasV6(true)
- {
- }
-
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- if (d_hasV6) {
- dnsquestion->ecs = std::make_unique<Netmask>(dnsquestion->ids.origRemote.isIPv4() ? d_v4 : d_v6);
- }
- else {
- dnsquestion->ecs = std::make_unique<Netmask>(d_v4);
- }
-
- return Action::None;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- std::string result = "set ECS to " + d_v4.toString();
- if (d_hasV6) {
- result += " / " + d_v6.toString();
- }
- return result;
- }
-
-private:
- Netmask d_v4;
- Netmask d_v6;
- bool d_hasV6;
-};
-
-#ifndef DISABLE_PROTOBUF
-static DnstapMessage::ProtocolType ProtocolToDNSTap(dnsdist::Protocol protocol)
-{
- if (protocol == dnsdist::Protocol::DoUDP) {
- return DnstapMessage::ProtocolType::DoUDP;
- }
- if (protocol == dnsdist::Protocol::DoTCP) {
- return DnstapMessage::ProtocolType::DoTCP;
- }
- if (protocol == dnsdist::Protocol::DoT) {
- return DnstapMessage::ProtocolType::DoT;
- }
- if (protocol == dnsdist::Protocol::DoH || protocol == dnsdist::Protocol::DoH3) {
- return DnstapMessage::ProtocolType::DoH;
- }
- if (protocol == dnsdist::Protocol::DNSCryptUDP) {
- return DnstapMessage::ProtocolType::DNSCryptUDP;
- }
- if (protocol == dnsdist::Protocol::DNSCryptTCP) {
- return DnstapMessage::ProtocolType::DNSCryptTCP;
- }
- if (protocol == dnsdist::Protocol::DoQ) {
- return DnstapMessage::ProtocolType::DoQ;
- }
- throw std::runtime_error("Unhandled protocol for dnstap: " + protocol.toPrettyString());
-}
-
-static void remoteLoggerQueueData(RemoteLoggerInterface& remoteLogger, const std::string& data)
-{
- auto ret = remoteLogger.queueData(data);
-
- switch (ret) {
- case RemoteLoggerInterface::Result::Queued:
- break;
- case RemoteLoggerInterface::Result::PipeFull: {
- vinfolog("%s: %s", remoteLogger.name(), RemoteLoggerInterface::toErrorString(ret));
- break;
- }
- case RemoteLoggerInterface::Result::TooLarge: {
- warnlog("%s: %s", remoteLogger.name(), RemoteLoggerInterface::toErrorString(ret));
- break;
- }
- case RemoteLoggerInterface::Result::OtherError:
- warnlog("%s: %s", remoteLogger.name(), RemoteLoggerInterface::toErrorString(ret));
- }
-}
-
-class DnstapLogAction : public DNSAction, public boost::noncopyable
-{
-public:
- // this action does not stop the processing
- DnstapLogAction(std::string identity, std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(DNSQuestion*, DnstapMessage*)>> alterFunc) :
- d_identity(std::move(identity)), d_logger(logger), d_alterFunc(std::move(alterFunc))
- {
- }
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- static thread_local std::string data;
- data.clear();
-
- DnstapMessage::ProtocolType protocol = ProtocolToDNSTap(dnsquestion->getProtocol());
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
- DnstapMessage message(std::move(data), !dnsquestion->getHeader()->qr ? DnstapMessage::MessageType::client_query : DnstapMessage::MessageType::client_response, d_identity, &dnsquestion->ids.origRemote, &dnsquestion->ids.origDest, protocol, reinterpret_cast<const char*>(dnsquestion->getData().data()), dnsquestion->getData().size(), &dnsquestion->getQueryRealTime(), nullptr);
- {
- if (d_alterFunc) {
- auto lock = g_lua.lock();
- (*d_alterFunc)(dnsquestion, &message);
- }
- }
-
- data = message.getBuffer();
- remoteLoggerQueueData(*d_logger, data);
-
- return Action::None;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "remote log as dnstap to " + (d_logger ? d_logger->toString() : "");
- }
-
-private:
- std::string d_identity;
- std::shared_ptr<RemoteLoggerInterface> d_logger;
- boost::optional<std::function<void(DNSQuestion*, DnstapMessage*)>> d_alterFunc;
-};
-
-namespace
-{
-void addMetaDataToProtobuf(DNSDistProtoBufMessage& message, const DNSQuestion& dnsquestion, const std::vector<std::pair<std::string, ProtoBufMetaKey>>& metas)
-{
- for (const auto& [name, meta] : metas) {
- message.addMeta(name, meta.getValues(dnsquestion), {});
- }
-}
-
-void addTagsToProtobuf(DNSDistProtoBufMessage& message, const DNSQuestion& dnsquestion, const std::unordered_set<std::string>& allowed)
-{
- if (!dnsquestion.ids.qTag) {
- return;
- }
-
- for (const auto& [key, value] : *dnsquestion.ids.qTag) {
- if (!allowed.empty() && allowed.count(key) == 0) {
- continue;
- }
-
- if (value.empty()) {
- message.addTag(key);
- }
- else {
- auto tag = key;
- tag.append(":");
- tag.append(value);
- message.addTag(tag);
- }
- }
-}
-
-void addExtendedDNSErrorToProtobuf(DNSDistProtoBufMessage& message, const DNSResponse& response, const std::string& metaKey)
-{
- auto [infoCode, extraText] = dnsdist::edns::getExtendedDNSError(response.getData());
- if (!infoCode) {
- return;
- }
-
- if (extraText) {
- message.addMeta(metaKey, {*extraText}, {*infoCode});
- }
- else {
- message.addMeta(metaKey, {}, {*infoCode});
- }
-}
-}
-
-struct RemoteLogActionConfiguration
-{
- std::vector<std::pair<std::string, ProtoBufMetaKey>> metas;
- std::optional<std::unordered_set<std::string>> tagsToExport{std::nullopt};
- boost::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)>> alterQueryFunc{boost::none};
- boost::optional<std::function<void(DNSResponse*, DNSDistProtoBufMessage*)>> alterResponseFunc{boost::none};
- std::shared_ptr<RemoteLoggerInterface> logger;
- std::string serverID;
- std::string ipEncryptKey;
- std::optional<std::string> exportExtendedErrorsToMeta{std::nullopt};
- bool includeCNAME{false};
-};
-
-class RemoteLogAction : public DNSAction, public boost::noncopyable
-{
-public:
- // this action does not stop the processing
- RemoteLogAction(RemoteLogActionConfiguration& config) :
- d_tagsToExport(std::move(config.tagsToExport)), d_metas(std::move(config.metas)), d_logger(config.logger), d_alterFunc(std::move(config.alterQueryFunc)), d_serverID(config.serverID), d_ipEncryptKey(config.ipEncryptKey)
- {
- }
-
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- if (!dnsquestion->ids.d_protoBufData) {
- dnsquestion->ids.d_protoBufData = std::make_unique<InternalQueryState::ProtoBufData>();
- }
- if (!dnsquestion->ids.d_protoBufData->uniqueId) {
- dnsquestion->ids.d_protoBufData->uniqueId = getUniqueID();
- }
-
- DNSDistProtoBufMessage message(*dnsquestion);
- if (!d_serverID.empty()) {
- message.setServerIdentity(d_serverID);
- }
-
-#ifdef HAVE_IPCIPHER
- if (!d_ipEncryptKey.empty()) {
- message.setRequestor(encryptCA(dnsquestion->ids.origRemote, d_ipEncryptKey));
- }
-#endif /* HAVE_IPCIPHER */
-
- if (d_tagsToExport) {
- addTagsToProtobuf(message, *dnsquestion, *d_tagsToExport);
- }
-
- addMetaDataToProtobuf(message, *dnsquestion, d_metas);
-
- if (d_alterFunc) {
- auto lock = g_lua.lock();
- (*d_alterFunc)(dnsquestion, &message);
- }
-
- static thread_local std::string data;
- data.clear();
- message.serialize(data);
- remoteLoggerQueueData(*d_logger, data);
-
- return Action::None;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "remote log to " + (d_logger ? d_logger->toString() : "");
- }
-
-private:
- std::optional<std::unordered_set<std::string>> d_tagsToExport;
- std::vector<std::pair<std::string, ProtoBufMetaKey>> d_metas;
- std::shared_ptr<RemoteLoggerInterface> d_logger;
- boost::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)>> d_alterFunc;
- std::string d_serverID;
- std::string d_ipEncryptKey;
-};
-
-#endif /* DISABLE_PROTOBUF */
-
-class SNMPTrapAction : public DNSAction
-{
-public:
- // this action does not stop the processing
- SNMPTrapAction(std::string reason) :
- d_reason(std::move(reason))
- {
- }
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- if (g_snmpAgent != nullptr && dnsdist::configuration::getImmutableConfiguration().d_snmpTrapsEnabled) {
- g_snmpAgent->sendDNSTrap(*dnsquestion, d_reason);
- }
-
- return Action::None;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "send SNMP trap";
- }
-
-private:
- std::string d_reason;
-};
-
-class SetTagAction : public DNSAction
-{
-public:
- // this action does not stop the processing
- SetTagAction(std::string tag, std::string value) :
- d_tag(std::move(tag)), d_value(std::move(value))
- {
- }
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- dnsquestion->setTag(d_tag, d_value);
-
- return Action::None;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "set tag '" + d_tag + "' to value '" + d_value + "'";
- }
-
-private:
- std::string d_tag;
- std::string d_value;
-};
-
-#ifndef DISABLE_PROTOBUF
-class DnstapLogResponseAction : public DNSResponseAction, public boost::noncopyable
-{
-public:
- // this action does not stop the processing
- DnstapLogResponseAction(std::string identity, std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(DNSResponse*, DnstapMessage*)>> alterFunc) :
- d_identity(std::move(identity)), d_logger(logger), d_alterFunc(std::move(alterFunc))
- {
- }
- DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
- {
- static thread_local std::string data;
- struct timespec now = {};
- gettime(&now, true);
- data.clear();
-
- DnstapMessage::ProtocolType protocol = ProtocolToDNSTap(response->getProtocol());
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
- DnstapMessage message(std::move(data), DnstapMessage::MessageType::client_response, d_identity, &response->ids.origRemote, &response->ids.origDest, protocol, reinterpret_cast<const char*>(response->getData().data()), response->getData().size(), &response->getQueryRealTime(), &now);
- {
- if (d_alterFunc) {
- auto lock = g_lua.lock();
- (*d_alterFunc)(response, &message);
- }
- }
-
- data = message.getBuffer();
- remoteLoggerQueueData(*d_logger, data);
-
- return Action::None;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "log response as dnstap to " + (d_logger ? d_logger->toString() : "");
- }
-
-private:
- std::string d_identity;
- std::shared_ptr<RemoteLoggerInterface> d_logger;
- boost::optional<std::function<void(DNSResponse*, DnstapMessage*)>> d_alterFunc;
-};
-
-class RemoteLogResponseAction : public DNSResponseAction, public boost::noncopyable
-{
-public:
- // this action does not stop the processing
- RemoteLogResponseAction(RemoteLogActionConfiguration& config) :
- d_tagsToExport(std::move(config.tagsToExport)), d_metas(std::move(config.metas)), d_logger(config.logger), d_alterFunc(std::move(config.alterResponseFunc)), d_serverID(config.serverID), d_ipEncryptKey(config.ipEncryptKey), d_exportExtendedErrorsToMeta(std::move(config.exportExtendedErrorsToMeta)), d_includeCNAME(config.includeCNAME)
- {
- }
- DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
- {
- if (!response->ids.d_protoBufData) {
- response->ids.d_protoBufData = std::make_unique<InternalQueryState::ProtoBufData>();
- }
- if (!response->ids.d_protoBufData->uniqueId) {
- response->ids.d_protoBufData->uniqueId = getUniqueID();
- }
-
- DNSDistProtoBufMessage message(*response, d_includeCNAME);
- if (!d_serverID.empty()) {
- message.setServerIdentity(d_serverID);
- }
-
-#ifdef HAVE_IPCIPHER
- if (!d_ipEncryptKey.empty()) {
- message.setRequestor(encryptCA(response->ids.origRemote, d_ipEncryptKey));
- }
-#endif /* HAVE_IPCIPHER */
-
- if (d_tagsToExport) {
- addTagsToProtobuf(message, *response, *d_tagsToExport);
- }
-
- addMetaDataToProtobuf(message, *response, d_metas);
-
- if (d_exportExtendedErrorsToMeta) {
- addExtendedDNSErrorToProtobuf(message, *response, *d_exportExtendedErrorsToMeta);
- }
-
- if (d_alterFunc) {
- auto lock = g_lua.lock();
- (*d_alterFunc)(response, &message);
- }
-
- static thread_local std::string data;
- data.clear();
- message.serialize(data);
- d_logger->queueData(data);
-
- return Action::None;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "remote log response to " + (d_logger ? d_logger->toString() : "");
- }
-
-private:
- std::optional<std::unordered_set<std::string>> d_tagsToExport;
- std::vector<std::pair<std::string, ProtoBufMetaKey>> d_metas;
- std::shared_ptr<RemoteLoggerInterface> d_logger;
- boost::optional<std::function<void(DNSResponse*, DNSDistProtoBufMessage*)>> d_alterFunc;
- std::string d_serverID;
- std::string d_ipEncryptKey;
- std::optional<std::string> d_exportExtendedErrorsToMeta{std::nullopt};
- bool d_includeCNAME;
-};
-
-#endif /* DISABLE_PROTOBUF */
-
-class DropResponseAction : public DNSResponseAction
-{
-public:
- DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
- {
- return Action::Drop;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "drop";
- }
-};
-
-class AllowResponseAction : public DNSResponseAction
-{
-public:
- DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
- {
- return Action::Allow;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "allow";
- }
-};
-
-class DelayResponseAction : public DNSResponseAction
-{
-public:
- DelayResponseAction(int msec) :
- d_msec(msec)
- {
- }
- DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
- {
- *ruleresult = std::to_string(d_msec);
- return Action::Delay;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "delay by " + std::to_string(d_msec) + " ms";
- }
-
-private:
- int d_msec;
-};
-
-#ifdef HAVE_NET_SNMP
-class SNMPTrapResponseAction : public DNSResponseAction
-{
-public:
- // this action does not stop the processing
- SNMPTrapResponseAction(std::string reason) :
- d_reason(std::move(reason))
- {
- }
- DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
- {
- if (g_snmpAgent != nullptr && dnsdist::configuration::getImmutableConfiguration().d_snmpTrapsEnabled) {
- g_snmpAgent->sendDNSTrap(*response, d_reason);
- }
-
- return Action::None;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "send SNMP trap";
- }
-
-private:
- std::string d_reason;
-};
-#endif /* HAVE_NET_SNMP */
-
-class SetTagResponseAction : public DNSResponseAction
-{
-public:
- // this action does not stop the processing
- SetTagResponseAction(std::string tag, std::string value) :
- d_tag(std::move(tag)), d_value(std::move(value))
- {
- }
- DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
- {
- response->setTag(d_tag, d_value);
-
- return Action::None;
- }
- [[nodiscard]] std::string toString() const override
- {
- return "set tag '" + d_tag + "' to value '" + d_value + "'";
- }
-
-private:
- std::string d_tag;
- std::string d_value;
-};
-
-class ClearRecordTypesResponseAction : public DNSResponseAction, public boost::noncopyable
-{
-public:
- ClearRecordTypesResponseAction(std::unordered_set<QType> qtypes) :
- d_qtypes(std::move(qtypes))
- {
- }
-
- DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
- {
- if (!d_qtypes.empty()) {
- clearDNSPacketRecordTypes(response->getMutableData(), d_qtypes);
- }
- return DNSResponseAction::Action::None;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return "clear record types";
- }
-
-private:
- std::unordered_set<QType> d_qtypes{};
-};
-
-class ContinueAction : public DNSAction
-{
-public:
- // this action does not stop the processing
- ContinueAction(std::shared_ptr<DNSAction>& action) :
- d_action(action)
- {
- }
-
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- if (d_action) {
- /* call the action */
- auto action = (*d_action)(dnsquestion, ruleresult);
- bool drop = false;
- /* apply the changes if needed (pool selection, flags, etc */
- processRulesResult(action, *dnsquestion, *ruleresult, drop);
- }
-
- /* but ignore the resulting action no matter what */
- return Action::None;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- if (d_action) {
- return "continue after: " + (d_action ? d_action->toString() : "");
- }
- return "no op";
- }
-
-private:
- std::shared_ptr<DNSAction> d_action;
-};
-
-#if defined(HAVE_DNS_OVER_HTTPS) || defined(HAVE_DNS_OVER_HTTP3)
-class HTTPStatusAction : public DNSAction
-{
-public:
- HTTPStatusAction(int code, PacketBuffer body, std::string contentType, dnsdist::ResponseConfig responseConfig) :
- d_responseConfig(responseConfig), d_body(std::move(body)), d_contentType(std::move(contentType)), d_code(code)
- {
- }
-
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
-#if defined(HAVE_DNS_OVER_HTTPS)
- if (dnsquestion->ids.du) {
- dnsquestion->ids.du->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType);
- dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) {
- header.qr = true; // for good measure
- setResponseHeadersFromConfig(header, d_responseConfig);
- return true;
- });
- return Action::HeaderModify;
- }
-#endif /* HAVE_DNS_OVER_HTTPS */
-#if defined(HAVE_DNS_OVER_HTTP3)
- if (dnsquestion->ids.doh3u) {
- dnsquestion->ids.doh3u->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType);
- dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) {
- header.qr = true; // for good measure
- setResponseHeadersFromConfig(header, d_responseConfig);
- return true;
- });
- return Action::HeaderModify;
- }
-#endif /* HAVE_DNS_OVER_HTTP3 */
- return Action::None;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return "return an HTTP status of " + std::to_string(d_code);
- }
-
-private:
- dnsdist::ResponseConfig d_responseConfig;
- PacketBuffer d_body;
- std::string d_contentType;
- int d_code;
-};
-#endif /* HAVE_DNS_OVER_HTTPS || HAVE_DNS_OVER_HTTP3 */
-
-#if defined(HAVE_LMDB) || defined(HAVE_CDB)
-class KeyValueStoreLookupAction : public DNSAction
-{
-public:
- // this action does not stop the processing
- KeyValueStoreLookupAction(std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey, std::string destinationTag) :
- d_kvs(kvs), d_key(lookupKey), d_tag(std::move(destinationTag))
- {
- }
-
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- std::vector<std::string> keys = d_key->getKeys(*dnsquestion);
- std::string result;
- for (const auto& key : keys) {
- if (d_kvs->getValue(key, result)) {
- break;
- }
- }
-
- dnsquestion->setTag(d_tag, std::move(result));
-
- return Action::None;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return "lookup key-value store based on '" + d_key->toString() + "' and set the result in tag '" + d_tag + "'";
- }
-
-private:
- std::shared_ptr<KeyValueStore> d_kvs;
- std::shared_ptr<KeyValueLookupKey> d_key;
- std::string d_tag;
-};
-
-class KeyValueStoreRangeLookupAction : public DNSAction
-{
-public:
- // this action does not stop the processing
- KeyValueStoreRangeLookupAction(std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey, std::string destinationTag) :
- d_kvs(kvs), d_key(lookupKey), d_tag(std::move(destinationTag))
- {
- }
-
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- std::vector<std::string> keys = d_key->getKeys(*dnsquestion);
- std::string result;
- for (const auto& key : keys) {
- if (d_kvs->getRangeValue(key, result)) {
- break;
- }
- }
-
- dnsquestion->setTag(d_tag, std::move(result));
-
- return Action::None;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return "do a range-based lookup in key-value store based on '" + d_key->toString() + "' and set the result in tag '" + d_tag + "'";
- }
-
-private:
- std::shared_ptr<KeyValueStore> d_kvs;
- std::shared_ptr<KeyValueLookupKey> d_key;
- std::string d_tag;
-};
-#endif /* defined(HAVE_LMDB) || defined(HAVE_CDB) */
-
-class MaxReturnedTTLAction : public DNSAction
-{
-public:
- MaxReturnedTTLAction(uint32_t cap) :
- d_cap(cap)
- {
- }
-
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- dnsquestion->ids.ttlCap = d_cap;
- return DNSAction::Action::None;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return "cap the TTL of the returned response to " + std::to_string(d_cap);
- }
-
-private:
- uint32_t d_cap;
-};
-
-class MaxReturnedTTLResponseAction : public DNSResponseAction
-{
-public:
- MaxReturnedTTLResponseAction(uint32_t cap) :
- d_cap(cap)
- {
- }
-
- DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
- {
- response->ids.ttlCap = d_cap;
- return DNSResponseAction::Action::None;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return "cap the TTL of the returned response to " + std::to_string(d_cap);
- }
-
-private:
- uint32_t d_cap;
-};
-
-class NegativeAndSOAAction : public DNSAction
-{
-public:
- struct SOAParams
- {
- uint32_t serial;
- uint32_t refresh;
- uint32_t retry;
- uint32_t expire;
- uint32_t minimum;
- };
-
- NegativeAndSOAAction(bool nxd, DNSName zone, uint32_t ttl, DNSName mname, DNSName rname, SOAParams params, bool soaInAuthoritySection, dnsdist::ResponseConfig responseConfig) :
- d_responseConfig(responseConfig), d_zone(std::move(zone)), d_mname(std::move(mname)), d_rname(std::move(rname)), d_ttl(ttl), d_params(params), d_nxd(nxd), d_soaInAuthoritySection(soaInAuthoritySection)
- {
- }
-
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- if (!setNegativeAndAdditionalSOA(*dnsquestion, d_nxd, d_zone, d_ttl, d_mname, d_rname, d_params.serial, d_params.refresh, d_params.retry, d_params.expire, d_params.minimum, d_soaInAuthoritySection)) {
- return Action::None;
- }
-
- dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) {
- setResponseHeadersFromConfig(header, d_responseConfig);
- return true;
- });
-
- return Action::Allow;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return std::string(d_nxd ? "NXD" : "NODATA") + " with SOA";
- }
-
-private:
- dnsdist::ResponseConfig d_responseConfig;
-
- DNSName d_zone;
- DNSName d_mname;
- DNSName d_rname;
- uint32_t d_ttl;
- SOAParams d_params;
- bool d_nxd;
- bool d_soaInAuthoritySection;
-};
-
-class SetProxyProtocolValuesAction : public DNSAction
-{
-public:
- // this action does not stop the processing
- SetProxyProtocolValuesAction(const std::vector<std::pair<uint8_t, std::string>>& values)
- {
- d_values.reserve(values.size());
- for (const auto& value : values) {
- d_values.push_back({value.second, value.first});
- }
- }
-
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- if (!dnsquestion->proxyProtocolValues) {
- dnsquestion->proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>();
- }
-
- *(dnsquestion->proxyProtocolValues) = d_values;
-
- return Action::None;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return "set Proxy-Protocol values";
- }
-
-private:
- std::vector<ProxyProtocolValue> d_values;
-};
-
-class SetAdditionalProxyProtocolValueAction : public DNSAction
-{
-public:
- // this action does not stop the processing
- SetAdditionalProxyProtocolValueAction(uint8_t type, std::string value) :
- d_value(std::move(value)), d_type(type)
- {
- }
-
- DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
- {
- if (!dnsquestion->proxyProtocolValues) {
- dnsquestion->proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>();
- }
-
- dnsquestion->proxyProtocolValues->push_back({d_value, d_type});
-
- return Action::None;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return "add a Proxy-Protocol value of type " + std::to_string(d_type);
- }
-
-private:
- std::string d_value;
- uint8_t d_type;
-};
-
-class SetReducedTTLResponseAction : public DNSResponseAction, public boost::noncopyable
-{
-public:
- // this action does not stop the processing
- SetReducedTTLResponseAction(uint8_t percentage) :
- d_ratio(percentage / 100.0)
- {
- }
-
- DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override
- {
- // NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
- auto visitor = [&](uint8_t section, uint16_t qclass, uint16_t qtype, uint32_t ttl) {
- return ttl * d_ratio;
- };
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
- editDNSPacketTTL(reinterpret_cast<char*>(response->getMutableData().data()), response->getData().size(), visitor);
- return DNSResponseAction::Action::None;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return "reduce ttl to " + std::to_string(d_ratio * 100) + " percent of its value";
- }
-
-private:
- double d_ratio{1.0};
-};
-
-class SetExtendedDNSErrorAction : public DNSAction
-{
-public:
- // this action does not stop the processing
- SetExtendedDNSErrorAction(uint16_t infoCode, const std::string& extraText)
- {
- d_ede.infoCode = infoCode;
- d_ede.extraText = extraText;
- }
-
- DNSAction::Action operator()(DNSQuestion* dnsQuestion, std::string* ruleresult) const override
- {
- dnsQuestion->ids.d_extendedError = std::make_unique<EDNSExtendedError>(d_ede);
-
- return DNSAction::Action::None;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return "set EDNS Extended DNS Error to " + std::to_string(d_ede.infoCode) + (d_ede.extraText.empty() ? std::string() : std::string(": \"") + d_ede.extraText + std::string("\""));
- }
-
-private:
- EDNSExtendedError d_ede;
-};
-
-class SetExtendedDNSErrorResponseAction : public DNSResponseAction
-{
-public:
- // this action does not stop the processing
- SetExtendedDNSErrorResponseAction(uint16_t infoCode, const std::string& extraText)
- {
- d_ede.infoCode = infoCode;
- d_ede.extraText = extraText;
- }
-
- DNSResponseAction::Action operator()(DNSResponse* dnsResponse, std::string* ruleresult) const override
- {
- dnsResponse->ids.d_extendedError = std::make_unique<EDNSExtendedError>(d_ede);
-
- return DNSResponseAction::Action::None;
- }
-
- [[nodiscard]] std::string toString() const override
- {
- return "set EDNS Extended DNS Error to " + std::to_string(d_ede.infoCode) + (d_ede.extraText.empty() ? std::string() : std::string(": \"") + d_ede.extraText + std::string("\""));
- }
-
-private:
- EDNSExtendedError d_ede;
-};
+#include "dnstap.hh"
+#include "remote_logger.hh"
template <typename ActionT, typename IdentifierT>
static void addAction(IdentifierT identifier, const luadnsrule_t& var, const std::shared_ptr<ActionT>& action, boost::optional<luaruleparams_t>& params)
return config;
}
+template <class T>
+static std::vector<T> convertLuaArrayToRegular(const LuaArray<T>& luaArray)
+{
+ std::vector<T> out;
+ out.reserve(luaArray.size());
+ for (const auto& entry : luaArray) {
+ out.emplace_back(entry.second);
+ }
+ return out;
+}
+
// NOLINTNEXTLINE(readability-function-cognitive-complexity): this function declares Lua bindings, even with a good refactoring it will likely blow up the threshold
void setupLuaActions(LuaContext& luaCtx)
{
luaCtx.registerFunction("reload", &DNSAction::reload);
luaCtx.registerFunction("reload", &DNSResponseAction::reload);
- luaCtx.writeFunction("LuaAction", [](LuaAction::func_t func) {
- setLuaSideEffect();
- return std::shared_ptr<DNSAction>(new LuaAction(std::move(func)));
- });
-
- luaCtx.writeFunction("LuaFFIAction", [](LuaFFIAction::func_t func) {
- setLuaSideEffect();
- return std::shared_ptr<DNSAction>(new LuaFFIAction(std::move(func)));
- });
-
- luaCtx.writeFunction("LuaFFIPerThreadAction", [](const std::string& code) {
- setLuaSideEffect();
- return std::shared_ptr<DNSAction>(new LuaFFIPerThreadAction(code));
- });
-
- luaCtx.writeFunction("SetNoRecurseAction", []() {
- return std::shared_ptr<DNSAction>(new SetNoRecurseAction);
- });
-
- luaCtx.writeFunction("SetMacAddrAction", [](int code) {
- return std::shared_ptr<DNSAction>(new SetMacAddrAction(code));
- });
-
- luaCtx.writeFunction("SetEDNSOptionAction", [](int code, const std::string& data) {
- return std::shared_ptr<DNSAction>(new SetEDNSOptionAction(code, data));
- });
-
- luaCtx.writeFunction("PoolAction", [](const std::string& poolname, boost::optional<bool> stopProcessing) {
- return std::shared_ptr<DNSAction>(new PoolAction(poolname, stopProcessing ? *stopProcessing : true));
- });
-
- luaCtx.writeFunction("QPSAction", [](int limit) {
- return std::shared_ptr<DNSAction>(new QPSAction(limit));
- });
-
- luaCtx.writeFunction("QPSPoolAction", [](int limit, const std::string& poolname, boost::optional<bool> stopProcessing) {
- return std::shared_ptr<DNSAction>(new QPSPoolAction(limit, poolname, stopProcessing ? *stopProcessing : true));
- });
-
luaCtx.writeFunction("SpoofAction", [](LuaTypeOrArrayOf<std::string> inp, boost::optional<responseParams_t> vars) {
vector<ComboAddress> addrs;
if (auto* ipaddr = boost::get<std::string>(&inp)) {
auto responseConfig = parseResponseConfig(vars);
checkAllParametersConsumed("SpoofAction", vars);
- auto ret = std::shared_ptr<DNSAction>(new SpoofAction(addrs, responseConfig));
+ auto ret = dnsdist::actions::getSpoofAction(addrs, responseConfig);
return ret;
});
luaCtx.writeFunction("SpoofSVCAction", [](const LuaArray<SVCRecordParameters>& parameters, boost::optional<responseParams_t> vars) {
auto responseConfig = parseResponseConfig(vars);
checkAllParametersConsumed("SpoofAction", vars);
- auto ret = std::shared_ptr<DNSAction>(new SpoofSVCAction(parameters, responseConfig));
+ auto svcParams = convertLuaArrayToRegular(parameters);
+ auto ret = dnsdist::actions::getSpoofSVCAction(svcParams, responseConfig);
return ret;
});
luaCtx.writeFunction("SpoofCNAMEAction", [](const std::string& cname, boost::optional<responseParams_t> vars) {
auto responseConfig = parseResponseConfig(vars);
checkAllParametersConsumed("SpoofCNAMEAction", vars);
- auto ret = std::shared_ptr<DNSAction>(new SpoofAction(DNSName(cname), responseConfig));
+ auto ret = dnsdist::actions::getSpoofAction(DNSName(cname), responseConfig);
return ret;
});
}
auto responseConfig = parseResponseConfig(vars);
checkAllParametersConsumed("SpoofRawAction", vars);
- auto ret = std::shared_ptr<DNSAction>(new SpoofAction(raws, qtypeForAnyParam, responseConfig));
+ auto ret = dnsdist::actions::getSpoofAction(raws, qtypeForAnyParam, responseConfig);
return ret;
});
throw std::runtime_error(std::string("SpoofPacketAction: given packet len is too small"));
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic)
- auto ret = std::shared_ptr<DNSAction>(new SpoofAction(PacketBuffer(response.data(), response.data() + len)));
+ auto ret = dnsdist::actions::getSpoofAction(PacketBuffer(response.data(), response.data() + len));
return ret;
});
- luaCtx.writeFunction("DropAction", []() {
- return std::shared_ptr<DNSAction>(new DropAction);
- });
-
- luaCtx.writeFunction("AllowAction", []() {
- return std::shared_ptr<DNSAction>(new AllowAction);
- });
-
- luaCtx.writeFunction("NoneAction", []() {
- return std::shared_ptr<DNSAction>(new NoneAction);
- });
-
- luaCtx.writeFunction("DelayAction", [](int msec) {
- return std::shared_ptr<DNSAction>(new DelayAction(msec));
- });
-
- luaCtx.writeFunction("TCAction", []() {
- return std::shared_ptr<DNSAction>(new TCAction);
- });
-
- luaCtx.writeFunction("TCResponseAction", []() {
- return std::shared_ptr<DNSResponseAction>(new TCResponseAction);
- });
-
- luaCtx.writeFunction("SetDisableValidationAction", []() {
- return std::shared_ptr<DNSAction>(new SetDisableValidationAction);
- });
-
- luaCtx.writeFunction("LogAction", [](boost::optional<std::string> fname, boost::optional<bool> binary, boost::optional<bool> append, boost::optional<bool> buffered, boost::optional<bool> verboseOnly, boost::optional<bool> includeTimestamp) {
- return std::shared_ptr<DNSAction>(new LogAction(fname ? *fname : "", binary ? *binary : true, append ? *append : false, buffered ? *buffered : false, verboseOnly ? *verboseOnly : true, includeTimestamp ? *includeTimestamp : false));
- });
-
- luaCtx.writeFunction("LogResponseAction", [](boost::optional<std::string> fname, boost::optional<bool> append, boost::optional<bool> buffered, boost::optional<bool> verboseOnly, boost::optional<bool> includeTimestamp) {
- return std::shared_ptr<DNSResponseAction>(new LogResponseAction(fname ? *fname : "", append ? *append : false, buffered ? *buffered : false, verboseOnly ? *verboseOnly : true, includeTimestamp ? *includeTimestamp : false));
- });
-
luaCtx.writeFunction("LimitTTLResponseAction", [](uint32_t min, uint32_t max, boost::optional<LuaArray<uint16_t>> types) {
std::unordered_set<QType> capTypes;
if (types) {
capTypes.insert(QType(type));
}
}
- return std::shared_ptr<DNSResponseAction>(new LimitTTLResponseAction(min, max, capTypes));
+ return dnsdist::actions::getLimitTTLResponseAction(min, max, capTypes);
});
luaCtx.writeFunction("SetMinTTLResponseAction", [](uint32_t min) {
- return std::shared_ptr<DNSResponseAction>(new LimitTTLResponseAction(min));
+ return dnsdist::actions::getLimitTTLResponseAction(min);
});
luaCtx.writeFunction("SetMaxTTLResponseAction", [](uint32_t max) {
- return std::shared_ptr<DNSResponseAction>(new LimitTTLResponseAction(0, max));
+ return dnsdist::actions::getLimitTTLResponseAction(0, max);
});
luaCtx.writeFunction("SetMaxReturnedTTLAction", [](uint32_t max) {
- return std::shared_ptr<DNSAction>(new MaxReturnedTTLAction(max));
+ return dnsdist::actions::getSetMaxReturnedTTLAction(max);
});
luaCtx.writeFunction("SetMaxReturnedTTLResponseAction", [](uint32_t max) {
- return std::shared_ptr<DNSResponseAction>(new MaxReturnedTTLResponseAction(max));
+ return dnsdist::actions::getSetMaxReturnedTTLResponseAction(max);
});
luaCtx.writeFunction("SetReducedTTLResponseAction", [](uint8_t percentage) {
if (percentage > 100) {
throw std::runtime_error(std::string("SetReducedTTLResponseAction takes a percentage between 0 and 100."));
}
- return std::shared_ptr<DNSResponseAction>(new SetReducedTTLResponseAction(percentage));
+ return dnsdist::actions::getSetReducedTTLResponseAction(percentage);
});
luaCtx.writeFunction("ClearRecordTypesResponseAction", [](LuaTypeOrArrayOf<int> types) {
qtypes.insert(tpair.second);
}
}
- return std::shared_ptr<DNSResponseAction>(new ClearRecordTypesResponseAction(std::move(qtypes)));
+ return dnsdist::actions::getClearRecordTypesResponseAction(std::move(qtypes));
});
luaCtx.writeFunction("RCodeAction", [](uint8_t rcode, boost::optional<responseParams_t> vars) {
auto responseConfig = parseResponseConfig(vars);
checkAllParametersConsumed("RCodeAction", vars);
- auto ret = std::shared_ptr<DNSAction>(new RCodeAction(rcode, responseConfig));
+ auto ret = dnsdist::actions::getRCodeAction(rcode, responseConfig);
return ret;
});
luaCtx.writeFunction("ERCodeAction", [](uint8_t rcode, boost::optional<responseParams_t> vars) {
auto responseConfig = parseResponseConfig(vars);
checkAllParametersConsumed("ERCodeAction", vars);
- auto ret = std::shared_ptr<DNSAction>(new ERCodeAction(rcode, responseConfig));
+ auto ret = dnsdist::actions::getERCodeAction(rcode, responseConfig);
return ret;
});
- luaCtx.writeFunction("SetSkipCacheAction", []() {
- return std::shared_ptr<DNSAction>(new SetSkipCacheAction);
- });
-
- luaCtx.writeFunction("SetSkipCacheResponseAction", []() {
- return std::shared_ptr<DNSResponseAction>(new SetSkipCacheResponseAction);
- });
-
- luaCtx.writeFunction("SetTempFailureCacheTTLAction", [](int maxTTL) {
- return std::shared_ptr<DNSAction>(new SetTempFailureCacheTTLAction(maxTTL));
- });
-
- luaCtx.writeFunction("DropResponseAction", []() {
- return std::shared_ptr<DNSResponseAction>(new DropResponseAction);
- });
-
- luaCtx.writeFunction("AllowResponseAction", []() {
- return std::shared_ptr<DNSResponseAction>(new AllowResponseAction);
- });
-
- luaCtx.writeFunction("DelayResponseAction", [](int msec) {
- return std::shared_ptr<DNSResponseAction>(new DelayResponseAction(msec));
- });
-
- luaCtx.writeFunction("LuaResponseAction", [](LuaResponseAction::func_t func) {
- setLuaSideEffect();
- return std::shared_ptr<DNSResponseAction>(new LuaResponseAction(std::move(func)));
- });
-
- luaCtx.writeFunction("LuaFFIResponseAction", [](LuaFFIResponseAction::func_t func) {
- setLuaSideEffect();
- return std::shared_ptr<DNSResponseAction>(new LuaFFIResponseAction(std::move(func)));
- });
-
- luaCtx.writeFunction("LuaFFIPerThreadResponseAction", [](const std::string& code) {
- setLuaSideEffect();
- return std::shared_ptr<DNSResponseAction>(new LuaFFIPerThreadResponseAction(code));
- });
-
#ifndef DISABLE_PROTOBUF
- luaCtx.writeFunction("RemoteLogAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)>> alterFunc, boost::optional<LuaAssociativeTable<std::string>> vars, boost::optional<LuaAssociativeTable<std::string>> metas) {
+ luaCtx.writeFunction("RemoteLogAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<dnsdist::actions::ProtobufAlterFunction> alterFunc, boost::optional<LuaAssociativeTable<std::string>> vars, boost::optional<LuaAssociativeTable<std::string>> metas) {
if (logger) {
// avoids potentially-evaluated-expression warning with clang.
RemoteLoggerInterface& remoteLoggerRef = *logger;
}
std::string tags;
- RemoteLogActionConfiguration config;
+ dnsdist::actions::RemoteLogActionConfiguration config;
config.logger = std::move(logger);
- config.alterQueryFunc = std::move(alterFunc);
+ if (alterFunc) {
+ config.alterQueryFunc = std::move(*alterFunc);
+ }
getOptionalValue<std::string>(vars, "serverID", config.serverID);
getOptionalValue<std::string>(vars, "ipEncryptKey", config.ipEncryptKey);
getOptionalValue<std::string>(vars, "exportTags", tags);
checkAllParametersConsumed("RemoteLogAction", vars);
- return std::shared_ptr<DNSAction>(new RemoteLogAction(config));
+ return dnsdist::actions::getRemoteLogAction(config);
});
- luaCtx.writeFunction("RemoteLogResponseAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSResponse*, DNSDistProtoBufMessage*)>> alterFunc, boost::optional<bool> includeCNAME, boost::optional<LuaAssociativeTable<std::string>> vars, boost::optional<LuaAssociativeTable<std::string>> metas) {
+ luaCtx.writeFunction("RemoteLogResponseAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<dnsdist::actions::ProtobufAlterResponseFunction> alterFunc, boost::optional<bool> includeCNAME, boost::optional<LuaAssociativeTable<std::string>> vars, boost::optional<LuaAssociativeTable<std::string>> metas) {
if (logger) {
// avoids potentially-evaluated-expression warning with clang.
RemoteLoggerInterface& remoteLoggerRef = *logger;
}
std::string tags;
- RemoteLogActionConfiguration config;
+ dnsdist::actions::RemoteLogActionConfiguration config;
config.logger = std::move(logger);
- config.alterResponseFunc = std::move(alterFunc);
+ if (alterFunc) {
+ config.alterResponseFunc = std::move(*alterFunc);
+ }
config.includeCNAME = includeCNAME ? *includeCNAME : false;
getOptionalValue<std::string>(vars, "serverID", config.serverID);
getOptionalValue<std::string>(vars, "ipEncryptKey", config.ipEncryptKey);
checkAllParametersConsumed("RemoteLogResponseAction", vars);
- return std::shared_ptr<DNSResponseAction>(new RemoteLogResponseAction(config));
+ return dnsdist::actions::getRemoteLogResponseAction(config);
});
- luaCtx.writeFunction("DnstapLogAction", [](const std::string& identity, std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSQuestion*, DnstapMessage*)>> alterFunc) {
- return std::shared_ptr<DNSAction>(new DnstapLogAction(identity, logger, std::move(alterFunc)));
+ luaCtx.writeFunction("DnstapLogAction", [](const std::string& identity, std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<dnsdist::actions::DnstapAlterFunction> alterFunc) {
+ return dnsdist::actions::getDnstapLogAction(identity, std::move(logger), alterFunc ? std::move(*alterFunc) : std::optional<dnsdist::actions::DnstapAlterFunction>());
});
- luaCtx.writeFunction("DnstapLogResponseAction", [](const std::string& identity, std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSResponse*, DnstapMessage*)>> alterFunc) {
- return std::shared_ptr<DNSResponseAction>(new DnstapLogResponseAction(identity, logger, std::move(alterFunc)));
+ luaCtx.writeFunction("DnstapLogResponseAction", [](const std::string& identity, std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<dnsdist::actions::DnstapAlterResponseFunction> alterFunc) {
+ return dnsdist::actions::getDnstapLogResponseAction(identity, std::move(logger), alterFunc ? std::move(*alterFunc) : std::optional<dnsdist::actions::DnstapAlterResponseFunction>());
});
#endif /* DISABLE_PROTOBUF */
luaCtx.writeFunction("TeeAction", [](const std::string& remote, boost::optional<bool> addECS, boost::optional<std::string> local, boost::optional<bool> addProxyProtocol) {
- boost::optional<ComboAddress> localAddr{boost::none};
+ std::optional<ComboAddress> localAddr;
if (local) {
localAddr = ComboAddress(*local, 0);
}
- return std::shared_ptr<DNSAction>(new TeeAction(ComboAddress(remote, 53), localAddr, addECS ? *addECS : false, addProxyProtocol ? *addProxyProtocol : false));
- });
-
- luaCtx.writeFunction("SetECSPrefixLengthAction", [](uint16_t v4PrefixLength, uint16_t v6PrefixLength) {
- return std::shared_ptr<DNSAction>(new SetECSPrefixLengthAction(v4PrefixLength, v6PrefixLength));
- });
-
- luaCtx.writeFunction("SetECSOverrideAction", [](bool ecsOverride) {
- return std::shared_ptr<DNSAction>(new SetECSOverrideAction(ecsOverride));
- });
-
- luaCtx.writeFunction("SetDisableECSAction", []() {
- return std::shared_ptr<DNSAction>(new SetDisableECSAction());
+ return dnsdist::actions::getTeeAction(ComboAddress(remote, 53), localAddr, addECS ? *addECS : false, addProxyProtocol ? *addProxyProtocol : false);
});
luaCtx.writeFunction("SetECSAction", [](const std::string& v4Netmask, boost::optional<std::string> v6Netmask) {
if (v6Netmask) {
- return std::shared_ptr<DNSAction>(new SetECSAction(Netmask(v4Netmask), Netmask(*v6Netmask)));
+ return dnsdist::actions::getSetECSAction(v4Netmask, *v6Netmask);
}
- return std::shared_ptr<DNSAction>(new SetECSAction(Netmask(v4Netmask)));
- });
-
-#ifdef HAVE_NET_SNMP
- luaCtx.writeFunction("SNMPTrapAction", [](boost::optional<std::string> reason) {
- return std::shared_ptr<DNSAction>(new SNMPTrapAction(reason ? *reason : ""));
- });
-
- luaCtx.writeFunction("SNMPTrapResponseAction", [](boost::optional<std::string> reason) {
- return std::shared_ptr<DNSResponseAction>(new SNMPTrapResponseAction(reason ? *reason : ""));
- });
-#endif /* HAVE_NET_SNMP */
-
- luaCtx.writeFunction("SetTagAction", [](const std::string& tag, const std::string& value) {
- return std::shared_ptr<DNSAction>(new SetTagAction(tag, value));
- });
-
- luaCtx.writeFunction("SetTagResponseAction", [](const std::string& tag, const std::string& value) {
- return std::shared_ptr<DNSResponseAction>(new SetTagResponseAction(tag, value));
+ return dnsdist::actions::getSetECSAction(v4Netmask);
});
luaCtx.writeFunction("ContinueAction", [](std::shared_ptr<DNSAction> action) {
- return std::shared_ptr<DNSAction>(new ContinueAction(action));
+ return dnsdist::actions::getContinueAction(std::move(action));
});
#ifdef HAVE_DNS_OVER_HTTPS
luaCtx.writeFunction("HTTPStatusAction", [](uint16_t status, std::string body, boost::optional<std::string> contentType, boost::optional<responseParams_t> vars) {
auto responseConfig = parseResponseConfig(vars);
checkAllParametersConsumed("HTTPStatusAction", vars);
- auto ret = std::shared_ptr<DNSAction>(new HTTPStatusAction(status, PacketBuffer(body.begin(), body.end()), contentType ? *contentType : "", responseConfig));
+ auto ret = dnsdist::actions::getHTTPStatusAction(status, PacketBuffer(body.begin(), body.end()), contentType ? *contentType : "", responseConfig);
return ret;
});
#endif /* HAVE_DNS_OVER_HTTPS */
#if defined(HAVE_LMDB) || defined(HAVE_CDB)
luaCtx.writeFunction("KeyValueStoreLookupAction", [](std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey, const std::string& destinationTag) {
- return std::shared_ptr<DNSAction>(new KeyValueStoreLookupAction(kvs, lookupKey, destinationTag));
+ return dnsdist::actions::getKeyValueStoreLookupAction(kvs, lookupKey, destinationTag);
});
luaCtx.writeFunction("KeyValueStoreRangeLookupAction", [](std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey, const std::string& destinationTag) {
- return std::shared_ptr<DNSAction>(new KeyValueStoreRangeLookupAction(kvs, lookupKey, destinationTag));
+ return dnsdist::actions::getKeyValueStoreRangeLookupAction(kvs, lookupKey, destinationTag);
});
#endif /* defined(HAVE_LMDB) || defined(HAVE_CDB) */
luaCtx.writeFunction("NegativeAndSOAAction", [](bool nxd, const std::string& zone, uint32_t ttl, const std::string& mname, const std::string& rname, uint32_t serial, uint32_t refresh, uint32_t retry, uint32_t expire, uint32_t minimum, boost::optional<responseParams_t> vars) {
- auto responseConfig = parseResponseConfig(vars);
bool soaInAuthoritySection = false;
getOptionalValue<bool>(vars, "soaInAuthoritySection", soaInAuthoritySection);
- NegativeAndSOAAction::SOAParams params{
+ auto responseConfig = parseResponseConfig(vars);
+ checkAllParametersConsumed("NegativeAndSOAAction", vars);
+ dnsdist::actions::SOAParams params{
.serial = serial,
.refresh = refresh,
.retry = retry,
.expire = expire,
.minimum = minimum};
- checkAllParametersConsumed("NegativeAndSOAAction", vars);
- auto ret = std::shared_ptr<DNSAction>(new NegativeAndSOAAction(nxd, DNSName(zone), ttl, DNSName(mname), DNSName(rname), params, soaInAuthoritySection, responseConfig));
+ auto ret = dnsdist::actions::getNegativeAndSOAAction(nxd, DNSName(zone), ttl, DNSName(mname), DNSName(rname), params, soaInAuthoritySection, responseConfig);
return ret;
});
luaCtx.writeFunction("SetProxyProtocolValuesAction", [](const std::vector<std::pair<uint8_t, std::string>>& values) {
- return std::shared_ptr<DNSAction>(new SetProxyProtocolValuesAction(values));
- });
-
- luaCtx.writeFunction("SetAdditionalProxyProtocolValueAction", [](uint8_t type, const std::string& value) {
- return std::shared_ptr<DNSAction>(new SetAdditionalProxyProtocolValueAction(type, value));
- });
-
- luaCtx.writeFunction("SetExtendedDNSErrorAction", [](uint16_t infoCode, boost::optional<std::string> extraText) {
- return std::shared_ptr<DNSAction>(new SetExtendedDNSErrorAction(infoCode, extraText ? *extraText : ""));
+ return dnsdist::actions::getSetProxyProtocolValuesAction(values);
});
- luaCtx.writeFunction("SetExtendedDNSErrorResponseAction", [](uint16_t infoCode, boost::optional<std::string> extraText) {
- return std::shared_ptr<DNSResponseAction>(new SetExtendedDNSErrorResponseAction(infoCode, extraText ? *extraText : ""));
- });
+// NOLINTNEXTLINE(bugprone-suspicious-include)
+#include "dnsdist-lua-actions-generated.cc"
+// NOLINTNEXTLINE(bugprone-suspicious-include)
+#include "dnsdist-lua-response-actions-generated.cc"
}
*/
#include "dnsdist.hh"
#include "dnsdist-lua.hh"
+#include "dnsdist-rules.hh"
#include "dnsdist-rules-factory.hh"
#include "dnsdist-rule-chains.hh"
#include "dns_random.hh"
if (rules.empty()) {
return;
}
- //coverity[auto_causes_copy]
+ // coverity[auto_causes_copy]
auto subject = *rules.rbegin();
rules.erase(std::prev(rules.end()));
rules.insert(rules.begin(), subject);
g_outputBuffer = "Error: attempt to move rules from/to invalid index\n";
return;
}
- //coverity[auto_causes_copy]
+ // coverity[auto_causes_copy]
auto subject = rules[from];
rules.erase(rules.begin() + from);
if (destination > rules.size()) {
const auto& smn = *boost::get<const SuffixMatchNode&>(&names);
return std::shared_ptr<DNSRule>(new SuffixMatchNodeRule(smn, quiet ? *quiet : false));
}
+
+template <class T>
+std::optional<T> boostToStandardOptional(const boost::optional<T>& boostOpt)
+{
+ return boostOpt ? *boostOpt : std::optional<T>();
+}
}
// NOLINTNEXTLINE(readability-function-cognitive-complexity): this function declares Lua bindings, even with a good refactoring it will likely blow up the threshold
});
}
- luaCtx.writeFunction("MaxQPSIPRule", [](unsigned int qps, boost::optional<unsigned int> ipv4trunc, boost::optional<unsigned int> ipv6trunc, boost::optional<unsigned int> burst, boost::optional<unsigned int> expiration, boost::optional<unsigned int> cleanupDelay, boost::optional<unsigned int> scanFraction, boost::optional<unsigned int> shards) {
- return std::shared_ptr<DNSRule>(new MaxQPSIPRule(qps, (burst ? *burst : qps), (ipv4trunc ? *ipv4trunc : 32), (ipv6trunc ? *ipv6trunc : 64), (expiration ? *expiration : 300), (cleanupDelay ? *cleanupDelay : 60), (scanFraction ? *scanFraction : 10), (shards ? *shards : 10)));
- });
-
- luaCtx.writeFunction("MaxQPSRule", [](unsigned int qps, boost::optional<unsigned int> burst) {
- if (!burst) {
- return std::shared_ptr<DNSRule>(new MaxQPSRule(qps));
- }
- return std::shared_ptr<DNSRule>(new MaxQPSRule(qps, *burst));
- });
-
- luaCtx.writeFunction("RegexRule", [](const std::string& str) {
- return std::shared_ptr<DNSRule>(new RegexRule(str));
- });
-
-#ifdef HAVE_DNS_OVER_HTTPS
- luaCtx.writeFunction("HTTPHeaderRule", [](const std::string& header, const std::string& regex) {
- return std::shared_ptr<DNSRule>(new HTTPHeaderRule(header, regex));
- });
- luaCtx.writeFunction("HTTPPathRule", [](const std::string& path) {
- return std::shared_ptr<DNSRule>(new HTTPPathRule(path));
- });
- luaCtx.writeFunction("HTTPPathRegexRule", [](const std::string& regex) {
- return std::shared_ptr<DNSRule>(new HTTPPathRegexRule(regex));
- });
-#endif
-
-#ifdef HAVE_RE2
- luaCtx.writeFunction("RE2Rule", [](const std::string& str) {
- return std::shared_ptr<DNSRule>(new RE2Rule(str));
- });
-#endif
-
- luaCtx.writeFunction("SNIRule", [](const std::string& name) {
- return std::shared_ptr<DNSRule>(new SNIRule(name));
- });
-
luaCtx.writeFunction("SuffixMatchNodeRule", qnameSuffixRule);
luaCtx.writeFunction("NetmaskGroupRule", [](const boost::variant<const NetmaskGroup&, std::string, const LuaArray<std::string>> netmasks, boost::optional<bool> src, boost::optional<bool> quiet) {
g_outputBuffer = (boost::format("Had %d matches out of %d, %.1f qps, in %.1f us\n") % matches % times % (1000000 * (1.0 * times / udiff)) % udiff).str();
});
- luaCtx.writeFunction("AllRule", []() {
- return std::shared_ptr<DNSRule>(new AllRule());
- });
-
- luaCtx.writeFunction("ProbaRule", [](double proba) {
- return std::shared_ptr<DNSRule>(new ProbaRule(proba));
- });
-
- luaCtx.writeFunction("QNameRule", [](const std::string& qname) {
- return std::shared_ptr<DNSRule>(new QNameRule(DNSName(qname)));
- });
-
luaCtx.writeFunction("QNameSuffixRule", qnameSuffixRule);
luaCtx.writeFunction("QTypeRule", [](boost::variant<unsigned int, std::string> str) {
return std::shared_ptr<DNSRule>(new QClassRule(cla));
});
- luaCtx.writeFunction("OpcodeRule", [](uint64_t code) {
- checkParameterBound("OpcodeRule", code, std::numeric_limits<uint8_t>::max());
- return std::shared_ptr<DNSRule>(new OpcodeRule(code));
- });
-
- luaCtx.writeFunction("AndRule", [](const LuaArray<std::shared_ptr<DNSRule>>& rules) {
+ luaCtx.writeFunction("AndRule", [](const LuaArray<std::shared_ptr<DNSRule>>& rulePairs) {
+ std::vector<std::shared_ptr<DNSRule>> rules;
+ rules.reserve(rulePairs.size());
+ for (const auto& pair : rulePairs) {
+ rules.emplace_back(pair.second);
+ }
return std::shared_ptr<DNSRule>(new AndRule(rules));
});
- luaCtx.writeFunction("OrRule", [](const LuaArray<std::shared_ptr<DNSRule>>& rules) {
+ luaCtx.writeFunction("OrRule", [](const LuaArray<std::shared_ptr<DNSRule>>& rulePairs) {
+ std::vector<std::shared_ptr<DNSRule>> rules;
+ rules.reserve(rulePairs.size());
+ for (const auto& pair : rulePairs) {
+ rules.emplace_back(pair.second);
+ }
return std::shared_ptr<DNSRule>(new OrRule(rules));
});
- luaCtx.writeFunction("DSTPortRule", [](uint64_t port) {
- checkParameterBound("DSTPortRule", port, std::numeric_limits<uint16_t>::max());
- return std::shared_ptr<DNSRule>(new DSTPortRule(port));
- });
-
- luaCtx.writeFunction("TCPRule", [](bool tcp) {
- return std::shared_ptr<DNSRule>(new TCPRule(tcp));
- });
-
- luaCtx.writeFunction("DNSSECRule", []() {
- return std::shared_ptr<DNSRule>(new DNSSECRule());
- });
-
luaCtx.writeFunction("NotRule", [](const std::shared_ptr<DNSRule>& rule) {
return std::shared_ptr<DNSRule>(new NotRule(rule));
});
- luaCtx.writeFunction("RecordsCountRule", [](uint64_t section, uint64_t minCount, uint64_t maxCount) {
- checkParameterBound("RecordsCountRule", section, std::numeric_limits<uint8_t>::max());
- checkParameterBound("RecordsCountRule", minCount, std::numeric_limits<uint16_t>::max());
- checkParameterBound("RecordsCountRule", maxCount, std::numeric_limits<uint16_t>::max());
- return std::shared_ptr<DNSRule>(new RecordsCountRule(section, minCount, maxCount));
- });
-
- luaCtx.writeFunction("RecordsTypeCountRule", [](uint64_t section, uint64_t type, uint64_t minCount, uint64_t maxCount) {
- checkParameterBound("RecordsTypeCountRule", section, std::numeric_limits<uint8_t>::max());
- checkParameterBound("RecordsTypeCountRule", type, std::numeric_limits<uint16_t>::max());
- checkParameterBound("RecordsTypeCountRule", minCount, std::numeric_limits<uint16_t>::max());
- checkParameterBound("RecordsTypeCountRule", maxCount, std::numeric_limits<uint16_t>::max());
- return std::shared_ptr<DNSRule>(new RecordsTypeCountRule(section, type, minCount, maxCount));
- });
-
- luaCtx.writeFunction("TrailingDataRule", []() {
- return std::shared_ptr<DNSRule>(new TrailingDataRule());
- });
-
- luaCtx.writeFunction("QNameLabelsCountRule", [](uint64_t minLabelsCount, uint64_t maxLabelsCount) {
- checkParameterBound("QNameLabelsCountRule", minLabelsCount, std::numeric_limits<unsigned int>::max());
- checkParameterBound("QNameLabelsCountRule", maxLabelsCount, std::numeric_limits<unsigned int>::max());
- return std::shared_ptr<DNSRule>(new QNameLabelsCountRule(minLabelsCount, maxLabelsCount));
- });
-
- luaCtx.writeFunction("QNameWireLengthRule", [](uint64_t min, uint64_t max) {
- return std::shared_ptr<DNSRule>(new QNameWireLengthRule(min, max));
- });
-
luaCtx.writeFunction("RCodeRule", [](uint64_t rcode) {
checkParameterBound("RCodeRule", rcode, std::numeric_limits<uint8_t>::max());
return std::shared_ptr<DNSRule>(new RCodeRule(rcode));
return std::shared_ptr<DNSRule>(new ERCodeRule(rcode));
});
- luaCtx.writeFunction("EDNSVersionRule", [](uint64_t version) {
- checkParameterBound("EDNSVersionRule", version, std::numeric_limits<uint8_t>::max());
- return std::shared_ptr<DNSRule>(new EDNSVersionRule(version));
- });
-
- luaCtx.writeFunction("EDNSOptionRule", [](uint64_t optcode) {
- checkParameterBound("EDNSOptionRule", optcode, std::numeric_limits<uint16_t>::max());
- return std::shared_ptr<DNSRule>(new EDNSOptionRule(optcode));
- });
-
- luaCtx.writeFunction("RDRule", []() {
- return std::shared_ptr<DNSRule>(new RDRule());
- });
-
- luaCtx.writeFunction("TagRule", [](const std::string& tag, boost::optional<std::string> value) {
- return std::shared_ptr<DNSRule>(new TagRule(tag, std::move(value)));
- });
-
luaCtx.writeFunction("TimedIPSetRule", []() {
return std::make_shared<TimedIPSetRule>();
});
- luaCtx.writeFunction("PoolAvailableRule", [](const std::string& poolname) {
- return std::shared_ptr<DNSRule>(new PoolAvailableRule(poolname));
- });
-
- luaCtx.writeFunction("PoolOutstandingRule", [](const std::string& poolname, uint64_t limit) {
- return std::shared_ptr<DNSRule>(new PoolOutstandingRule(poolname, limit));
- });
-
luaCtx.registerFunction<void (std::shared_ptr<TimedIPSetRule>::*)()>("clear", [](const std::shared_ptr<TimedIPSetRule>& tisr) {
tisr->clear();
});
tisr->toString();
});
+ luaCtx.writeFunction("QNameRule", [](const std::string& qname) {
+ return std::shared_ptr<DNSRule>(dnsdist::selectors::getQNameSelector(DNSName(qname)));
+ });
+
luaCtx.writeFunction("QNameSetRule", [](const DNSNameSet& names) {
return std::shared_ptr<DNSRule>(new QNameSetRule(names));
});
});
#endif /* defined(HAVE_LMDB) || defined(HAVE_CDB) */
- luaCtx.writeFunction("LuaRule", [](const LuaRule::func_t& func) {
- return std::shared_ptr<DNSRule>(new LuaRule(func));
- });
-
- luaCtx.writeFunction("LuaFFIRule", [](const LuaFFIRule::func_t& func) {
- return std::shared_ptr<DNSRule>(new LuaFFIRule(func));
- });
-
- luaCtx.writeFunction("LuaFFIPerThreadRule", [](const std::string& code) {
- return std::shared_ptr<DNSRule>(new LuaFFIPerThreadRule(code));
- });
-
- luaCtx.writeFunction("ProxyProtocolValueRule", [](uint8_t type, boost::optional<std::string> value) {
- return std::shared_ptr<DNSRule>(new ProxyProtocolValueRule(type, std::move(value)));
- });
-
- luaCtx.writeFunction("PayloadSizeRule", [](const std::string& comparison, uint16_t size) {
- return std::shared_ptr<DNSRule>(new PayloadSizeRule(comparison, size));
- });
+// NOLINTNEXTLINE(bugprone-suspicious-include)
+#include "dnsdist-lua-selectors-generated.cc"
}
*/
#pragma once
-#include "dnsdist-rules.hh"
+#include <boost/multi_index_container.hpp>
+#include <boost/multi_index/ordered_index.hpp>
+#include <boost/multi_index/sequenced_index.hpp>
+#include <boost/multi_index/key_extractors.hpp>
#include "cachecleaner.hh"
#include "dnsdist-ecs.hh"
#include "dnsdist-kvs.hh"
#include "dnsdist-lua.hh"
#include "dnsdist-lua-ffi.hh"
+#include "dnsdist-rules.hh"
#include "dolog.hh"
#include "dnsparser.hh"
#include "dns_random.hh"
+#include "uuid-utils.hh"
-#include <boost/multi_index_container.hpp>
-#include <boost/multi_index/ordered_index.hpp>
-#include <boost/multi_index/sequenced_index.hpp>
-#include <boost/multi_index/key_extractors.hpp>
+namespace dnsdist::selectors
+{
+using LuaSelectorFunction = std::function<bool(const DNSQuestion* dq)>;
+using LuaSelectorFFIFunction = std::function<bool(dnsdist_ffi_dnsquestion_t* dq)>;
+}
class MaxQPSIPRule : public DNSRule
{
public:
- MaxQPSIPRule(unsigned int qps, unsigned int burst, unsigned int ipv4trunc=32, unsigned int ipv6trunc=64, unsigned int expiration=300, unsigned int cleanupDelay=60, unsigned int scanFraction=10, size_t shardsCount=10):
- d_shards(shardsCount), d_qps(qps), d_burst(burst), d_ipv4trunc(ipv4trunc), d_ipv6trunc(ipv6trunc), d_cleanupDelay(cleanupDelay), d_expiration(expiration), d_scanFraction(scanFraction)
+ MaxQPSIPRule(unsigned int qps, unsigned int ipv4trunc = 32, unsigned int ipv6trunc = 64, unsigned int burst = 0, unsigned int expiration = 300, unsigned int cleanupDelay = 60, unsigned int scanFraction = 10, size_t shardsCount = 10) :
+ d_shards(shardsCount), d_qps(qps), d_burst(burst == 0 ? qps : burst), d_ipv4trunc(ipv4trunc), d_ipv6trunc(ipv6trunc), d_cleanupDelay(cleanupDelay), d_expiration(expiration), d_scanFraction(scanFraction)
{
d_cleaningUp.clear();
gettime(&d_lastCleanup, true);
}
}
- size_t cleanup(const struct timespec& cutOff, size_t* scannedCount=nullptr) const
+ size_t cleanup(const struct timespec& cutOff, size_t* scannedCount = nullptr) const
{
size_t removed = 0;
if (scannedCount != nullptr) {
for (auto& shard : d_shards) {
auto limits = shard.lock();
- const size_t toLook = std::round((1.0 * limits->size()) / d_scanFraction)+ 1;
+ const size_t toLook = std::round((1.0 * limits->size()) / d_scanFraction) + 1;
size_t lookedAt = 0;
auto& sequence = limits->get<SequencedTag>();
cleanupIfNeeded(dq->getQueryRealTime());
ComboAddress zeroport(dq->ids.origRemote);
- zeroport.sin4.sin_port=0;
+ zeroport.sin4.sin_port = 0;
zeroport.truncate(zeroport.sin4.sin_family == AF_INET ? d_ipv4trunc : d_ipv6trunc);
auto hash = ComboAddress::addressOnlyHash()(zeroport);
auto& shard = d_shards[hash % d_shards.size()];
string toString() const override
{
- return "IP (/"+std::to_string(d_ipv4trunc)+", /"+std::to_string(d_ipv6trunc)+") match for QPS over " + std::to_string(d_qps) + " burst "+ std::to_string(d_burst);
+ return "IP (/" + std::to_string(d_ipv4trunc) + ", /" + std::to_string(d_ipv6trunc) + ") match for QPS over " + std::to_string(d_qps) + " burst " + std::to_string(d_burst);
}
size_t getEntriesCount() const
}
private:
- struct HashedTag {};
- struct SequencedTag {};
+ struct HashedTag
+ {
+ };
+ struct SequencedTag
+ {
+ };
struct Entry
{
- Entry(const ComboAddress& addr, BasicQPSLimiter&& limiter): d_limiter(limiter), d_addr(addr)
+ Entry(const ComboAddress& addr, BasicQPSLimiter&& limiter) :
+ d_limiter(limiter), d_addr(addr)
{
}
mutable BasicQPSLimiter d_limiter;
ComboAddress d_addr;
};
- typedef multi_index_container<
+ using qpsContainer_t = multi_index_container<
Entry,
- indexed_by <
- hashed_unique<tag<HashedTag>, member<Entry,ComboAddress,&Entry::d_addr>, ComboAddress::addressOnlyHash >,
- sequenced<tag<SequencedTag> >
- >
- > qpsContainer_t;
+ indexed_by<
+ hashed_unique<tag<HashedTag>, member<Entry, ComboAddress, &Entry::d_addr>, ComboAddress::addressOnlyHash>,
+ sequenced<tag<SequencedTag>>>>;
mutable std::vector<LockGuarded<qpsContainer_t>> d_shards;
mutable struct timespec d_lastCleanup;
class MaxQPSRule : public DNSRule
{
public:
- MaxQPSRule(unsigned int qps)
- : d_qps(qps, qps)
- {}
-
- MaxQPSRule(unsigned int qps, unsigned int burst)
- : d_qps(qps, burst)
- {}
+ MaxQPSRule(unsigned int qps) :
+ d_qps(qps, qps)
+ {
+ }
+ MaxQPSRule(unsigned int qps, unsigned int burst) :
+ d_qps(qps, burst > 0 ? burst : qps)
+ {
+ }
bool matches(const DNSQuestion* qd) const override
{
return "Max " + std::to_string(d_qps.getRate()) + " qps";
}
-
private:
mutable QPSLimiter d_qps;
};
-class NMGRule : public DNSRule
+class NetmaskGroupRule : public DNSRule
{
public:
- NMGRule(const NetmaskGroup& nmg) : d_nmg(nmg) {}
-protected:
- NetmaskGroup d_nmg;
-};
-
-class NetmaskGroupRule : public NMGRule
-{
-public:
- NetmaskGroupRule(const NetmaskGroup& nmg, bool src, bool quiet = false) : NMGRule(nmg)
+ NetmaskGroupRule(const NetmaskGroup& nmg, bool src, bool quiet = false) :
+ d_nmg(nmg)
{
- d_src = src;
- d_quiet = quiet;
+ d_src = src;
+ d_quiet = quiet;
}
bool matches(const DNSQuestion* dq) const override
{
- if(!d_src) {
- return d_nmg.match(dq->ids.origDest);
+ if (!d_src) {
+ return d_nmg.match(dq->ids.origDest);
}
return d_nmg.match(dq->ids.origRemote);
}
string toString() const override
{
string ret = "Src: ";
- if(!d_src) {
- ret = "Dst: ";
+ if (!d_src) {
+ ret = "Dst: ";
}
if (d_quiet) {
return ret + "in-group";
}
return ret + d_nmg.toString();
}
+
private:
+ NetmaskGroup d_nmg;
bool d_src;
bool d_quiet;
};
class TimedIPSetRule : public DNSRule, boost::noncopyable
{
private:
- struct IPv6 {
+ struct IPv6
+ {
IPv6(const ComboAddress& ca)
{
- static_assert(sizeof(*this)==16, "IPv6 struct has wrong size");
+ static_assert(sizeof(*this) == 16, "IPv6 struct has wrong size");
memcpy((char*)this, ca.sin6.sin6_addr.s6_addr, 16);
}
bool operator==(const IPv6& rhs) const
{
- return a==rhs.a && b==rhs.b;
+ return a == rhs.a && b == rhs.b;
}
uint64_t a, b;
};
return false;
}
return time(nullptr) < fnd->second;
- } else {
+ }
+ else {
auto ip6s = d_ip6s.read_lock();
auto fnd = ip6s->find({dq->ids.origRemote});
if (fnd == ip6s->end()) {
time_t now = time(nullptr);
{
auto ip4s = d_ip4s.write_lock();
- for (auto iter = ip4s->begin(); iter != ip4s->end(); ) {
- if (iter->second < now) {
- iter = ip4s->erase(iter);
+ for (auto iter = ip4s->begin(); iter != ip4s->end();) {
+ if (iter->second < now) {
+ iter = ip4s->erase(iter);
}
- else {
- ++iter;
+ else {
+ ++iter;
}
}
}
{
auto ip6s = d_ip6s.write_lock();
- for (auto iter = ip6s->begin(); iter != ip6s->end(); ) {
- if (iter->second < now) {
- iter = ip6s->erase(iter);
+ for (auto iter = ip6s->begin(); iter != ip6s->end();) {
+ if (iter->second < now) {
+ iter = ip6s->erase(iter);
}
- else {
- ++iter;
+ else {
+ ++iter;
}
}
-
}
-
}
string toString() const override
}
}
- return "Src: "+std::to_string(count)+" ips";
+ return "Src: " + std::to_string(count) + " ips";
}
+
private:
struct IPv6Hash
{
std::size_t operator()(const IPv6& ip) const
{
- auto ah=std::hash<uint64_t>{}(ip.a);
- auto bh=std::hash<uint64_t>{}(ip.b);
- return ah & (bh<<1);
+ auto ah = std::hash<uint64_t>{}(ip.a);
+ auto bh = std::hash<uint64_t>{}(ip.b);
+ return ah & (bh << 1);
}
};
mutable SharedLockGuarded<std::unordered_map<IPv6, time_t, IPv6Hash>> d_ip6s;
mutable SharedLockGuarded<std::unordered_map<uint32_t, time_t>> d_ip4s;
};
-
class AllRule : public DNSRule
{
public:
{
return "All";
}
-
};
-
class DNSSECRule : public DNSRule
{
public:
DNSSECRule()
{
-
}
bool matches(const DNSQuestion* dq) const override
{
- return dq->getHeader()->cd || (dnsdist::getEDNSZ(*dq) & EDNS_HEADER_FLAG_DO); // turns out dig sets ad by default..
+ return dq->getHeader()->cd || (dnsdist::getEDNSZ(*dq) & EDNS_HEADER_FLAG_DO); // turns out dig sets ad by default..
}
string toString() const override
class AndRule : public DNSRule
{
public:
- AndRule(const std::vector<pair<int, std::shared_ptr<DNSRule> > >& rules)
+ AndRule(const std::vector<std::shared_ptr<DNSRule>>& rules) :
+ d_rules(rules)
{
- for (const auto& r : rules) {
- d_rules.push_back(r.second);
- }
}
bool matches(const DNSQuestion* dq) const override
string ret;
for (const auto& rule : d_rules) {
if (!ret.empty()) {
- ret+= " && ";
+ ret += " && ";
}
- ret += "("+ rule->toString()+")";
+ ret += "(" + rule->toString() + ")";
}
return ret;
}
+
private:
- std::vector<std::shared_ptr<DNSRule> > d_rules;
+ std::vector<std::shared_ptr<DNSRule>> d_rules;
};
-
class OrRule : public DNSRule
{
public:
- OrRule(const std::vector<pair<int, std::shared_ptr<DNSRule> > >& rules)
+ OrRule(const std::vector<std::shared_ptr<DNSRule>>& rules) :
+ d_rules(rules)
{
- for (const auto& r : rules) {
- d_rules.push_back(r.second);
- }
}
bool matches(const DNSQuestion* dq) const override
{
- for (const auto& rule: d_rules) {
+ for (const auto& rule : d_rules) {
if (rule->matches(dq)) {
return true;
}
string ret;
for (const auto& rule : d_rules) {
if (!ret.empty()) {
- ret+= " || ";
+ ret += " || ";
}
- ret += "("+ rule->toString()+")";
+ ret += "(" + rule->toString() + ")";
}
return ret;
}
+
private:
- std::vector<std::shared_ptr<DNSRule> > d_rules;
+ std::vector<std::shared_ptr<DNSRule>> d_rules;
};
-
class RegexRule : public DNSRule
{
public:
- RegexRule(const std::string& regex) : d_regex(regex), d_visual(regex)
+ RegexRule(const std::string& regex) :
+ d_regex(regex), d_visual(regex)
{
-
}
bool matches(const DNSQuestion* dq) const override
{
string toString() const override
{
- return "Regex: "+d_visual;
+ return "Regex: " + d_visual;
}
+
private:
Regex d_regex;
string d_visual;
};
-#ifdef HAVE_RE2
+#if defined(HAVE_RE2)
#include <re2/re2.h>
class RE2Rule : public DNSRule
{
public:
- RE2Rule(const std::string& re2) : d_re2(re2, RE2::Latin1), d_visual(re2)
+ RE2Rule(const std::string& re2) :
+ d_re2(re2, RE2::Latin1), d_visual(re2)
{
-
}
bool matches(const DNSQuestion* dq) const override
{
string toString() const override
{
- return "RE2 match: "+d_visual;
+ return "RE2 match: " + d_visual;
}
+
private:
RE2 d_re2;
string d_visual;
};
-#endif
+#else /* HAVE_RE2 */
+class RE2Rule : public DNSRule
+{
+public:
+ RE2Rule(const std::string& re2)
+ {
+ throw std::runtime_error("RE2 support is disabled");
+ }
+ bool matches(const DNSQuestion* dq) const override
+ {
+ return false;
+ }
+
+ string toString() const override
+ {
+ return "Unsupported RE2";
+ }
+};
+#endif /* HAVE_RE2 */
-#ifdef HAVE_DNS_OVER_HTTPS
class HTTPHeaderRule : public DNSRule
{
public:
HTTPHeaderRule(const std::string& header, const std::string& regex);
- bool matches(const DNSQuestion* dq) const override;
+ bool matches(const DNSQuestion* dnsQuestion) const override;
string toString() const override;
+
private:
string d_header;
Regex d_regex;
{
public:
HTTPPathRule(std::string path);
- bool matches(const DNSQuestion* dq) const override;
+ bool matches(const DNSQuestion* dnsQuestion) const override;
string toString() const override;
+
private:
string d_path;
};
{
public:
HTTPPathRegexRule(const std::string& regex);
- bool matches(const DNSQuestion* dq) const override;
+ bool matches(const DNSQuestion* dnsQuestion) const override;
string toString() const override;
+
private:
Regex d_regex;
std::string d_visual;
};
-#endif
class SNIRule : public DNSRule
{
public:
- SNIRule(const std::string& name) : d_sni(name)
+ SNIRule(const std::string& name) :
+ d_sni(name)
{
}
bool matches(const DNSQuestion* dq) const override
{
return "SNI == " + d_sni;
}
+
private:
std::string d_sni;
};
class SuffixMatchNodeRule : public DNSRule
{
public:
- SuffixMatchNodeRule(const SuffixMatchNode& smn, bool quiet=false) : d_smn(smn), d_quiet(quiet)
+ SuffixMatchNodeRule(const SuffixMatchNode& smn, bool quiet = false) :
+ d_smn(smn), d_quiet(quiet)
{
}
bool matches(const DNSQuestion* dq) const override
}
string toString() const override
{
- if(d_quiet)
+ if (d_quiet)
return "qname==in-set";
else
- return "qname in "+d_smn.toString();
+ return "qname in " + d_smn.toString();
}
+
private:
SuffixMatchNode d_smn;
bool d_quiet;
class QNameRule : public DNSRule
{
public:
- QNameRule(const DNSName& qname) : d_qname(qname)
+ QNameRule(const DNSName& qname) :
+ d_qname(qname)
{
}
bool matches(const DNSQuestion* dq) const override
{
- return d_qname==dq->ids.qname;
+ return d_qname == dq->ids.qname;
}
string toString() const override
{
- return "qname=="+d_qname.toString();
+ return "qname==" + d_qname.toString();
}
+
private:
DNSName d_qname;
};
-class QNameSetRule : public DNSRule {
+class QNameSetRule : public DNSRule
+{
public:
- QNameSetRule(const DNSNameSet& names) : qname_idx(names) {}
+ QNameSetRule(const DNSNameSet& names) :
+ qname_idx(names) {}
- bool matches(const DNSQuestion* dq) const override {
- return qname_idx.find(dq->ids.qname) != qname_idx.end();
- }
+ bool matches(const DNSQuestion* dq) const override
+ {
+ return qname_idx.find(dq->ids.qname) != qname_idx.end();
+ }
+
+ string toString() const override
+ {
+ std::stringstream ss;
+ ss << "qname in DNSNameSet(" << qname_idx.size() << " FQDNs)";
+ return ss.str();
+ }
- string toString() const override {
- std::stringstream ss;
- ss << "qname in DNSNameSet(" << qname_idx.size() << " FQDNs)";
- return ss.str();
- }
private:
- DNSNameSet qname_idx;
+ DNSNameSet qname_idx;
};
class QTypeRule : public DNSRule
{
public:
- QTypeRule(uint16_t qtype) : d_qtype(qtype)
+ QTypeRule(uint16_t qtype) :
+ d_qtype(qtype)
{
}
bool matches(const DNSQuestion* dq) const override
string toString() const override
{
QType qt(d_qtype);
- return "qtype=="+qt.toString();
+ return "qtype==" + qt.toString();
}
+
private:
uint16_t d_qtype;
};
class QClassRule : public DNSRule
{
public:
- QClassRule(uint16_t qclass) : d_qclass(qclass)
+ QClassRule(uint16_t qclass) :
+ d_qclass(qclass)
{
}
bool matches(const DNSQuestion* dq) const override
}
string toString() const override
{
- return "qclass=="+std::to_string(d_qclass);
+ return "qclass==" + std::to_string(d_qclass);
}
+
private:
uint16_t d_qclass;
};
class OpcodeRule : public DNSRule
{
public:
- OpcodeRule(uint8_t opcode) : d_opcode(opcode)
+ OpcodeRule(uint8_t opcode) :
+ d_opcode(opcode)
{
}
bool matches(const DNSQuestion* dq) const override
}
string toString() const override
{
- return "opcode=="+std::to_string(d_opcode);
+ return "opcode==" + std::to_string(d_opcode);
}
+
private:
uint8_t d_opcode;
};
class DSTPortRule : public DNSRule
{
public:
- DSTPortRule(uint16_t port) : d_port(port)
+ DSTPortRule(uint16_t port) :
+ d_port(port)
{
}
bool matches(const DNSQuestion* dq) const override
}
string toString() const override
{
- return "dst port=="+std::to_string(d_port);
+ return "dst port==" + std::to_string(d_port);
}
+
private:
uint16_t d_port;
};
class TCPRule : public DNSRule
{
public:
- TCPRule(bool tcp): d_tcp(tcp)
+ TCPRule(bool tcp) :
+ d_tcp(tcp)
{
}
bool matches(const DNSQuestion* dq) const override
{
return (d_tcp ? "TCP" : "UDP");
}
+
private:
bool d_tcp;
};
-
class NotRule : public DNSRule
{
public:
- NotRule(const std::shared_ptr<DNSRule>& rule): d_rule(rule)
+ NotRule(const std::shared_ptr<DNSRule>& rule) :
+ d_rule(rule)
{
}
bool matches(const DNSQuestion* dq) const override
}
string toString() const override
{
- return "!("+ d_rule->toString()+")";
+ return "!(" + d_rule->toString() + ")";
}
+
private:
std::shared_ptr<DNSRule> d_rule;
};
class RecordsCountRule : public DNSRule
{
public:
- RecordsCountRule(uint8_t section, uint16_t minCount, uint16_t maxCount): d_minCount(minCount), d_maxCount(maxCount), d_section(section)
+ RecordsCountRule(uint8_t section, uint16_t minCount, uint16_t maxCount) :
+ d_minCount(minCount), d_maxCount(maxCount), d_section(section)
{
}
bool matches(const DNSQuestion* dq) const override
{
uint16_t count = 0;
- switch(d_section) {
+ switch (d_section) {
case 0:
count = ntohs(dq->getHeader()->qdcount);
break;
string toString() const override
{
string section;
- switch(d_section) {
+ switch (d_section) {
case 0:
section = "QD";
break;
section = "AR";
break;
}
- return std::to_string(d_minCount) + " <= records in " + section + " <= "+ std::to_string(d_maxCount);
+ return std::to_string(d_minCount) + " <= records in " + section + " <= " + std::to_string(d_maxCount);
}
+
private:
uint16_t d_minCount;
uint16_t d_maxCount;
class RecordsTypeCountRule : public DNSRule
{
public:
- RecordsTypeCountRule(uint8_t section, uint16_t type, uint16_t minCount, uint16_t maxCount): d_type(type), d_minCount(minCount), d_maxCount(maxCount), d_section(section)
+ RecordsTypeCountRule(uint8_t section, uint16_t type, uint16_t minCount, uint16_t maxCount) :
+ d_type(type), d_minCount(minCount), d_maxCount(maxCount), d_section(section)
{
}
bool matches(const DNSQuestion* dq) const override
{
uint16_t count = 0;
- switch(d_section) {
+ switch (d_section) {
case 0:
count = ntohs(dq->getHeader()->qdcount);
break;
string toString() const override
{
string section;
- switch(d_section) {
+ switch (d_section) {
case 0:
section = "QD";
break;
section = "AR";
break;
}
- return std::to_string(d_minCount) + " <= " + QType(d_type).toString() + " records in " + section + " <= "+ std::to_string(d_maxCount);
+ return std::to_string(d_minCount) + " <= " + QType(d_type).toString() + " records in " + section + " <= " + std::to_string(d_maxCount);
}
+
private:
uint16_t d_type;
uint16_t d_minCount;
class QNameLabelsCountRule : public DNSRule
{
public:
- QNameLabelsCountRule(unsigned int minLabelsCount, unsigned int maxLabelsCount): d_min(minLabelsCount), d_max(maxLabelsCount)
+ QNameLabelsCountRule(unsigned int minLabelsCount, unsigned int maxLabelsCount) :
+ d_min(minLabelsCount), d_max(maxLabelsCount)
{
}
bool matches(const DNSQuestion* dq) const override
{
return "labels count < " + std::to_string(d_min) + " || labels count > " + std::to_string(d_max);
}
+
private:
unsigned int d_min;
unsigned int d_max;
class QNameWireLengthRule : public DNSRule
{
public:
- QNameWireLengthRule(size_t min, size_t max): d_min(min), d_max(max)
+ QNameWireLengthRule(size_t min, size_t max) :
+ d_min(min), d_max(max)
{
}
bool matches(const DNSQuestion* dq) const override
{
return "wire length < " + std::to_string(d_min) + " || wire length > " + std::to_string(d_max);
}
+
private:
size_t d_min;
size_t d_max;
class RCodeRule : public DNSRule
{
public:
- RCodeRule(uint8_t rcode) : d_rcode(rcode)
+ RCodeRule(uint8_t rcode) :
+ d_rcode(rcode)
{
}
bool matches(const DNSQuestion* dq) const override
}
string toString() const override
{
- return "rcode=="+RCode::to_s(d_rcode);
+ return "rcode==" + RCode::to_s(d_rcode);
}
+
private:
uint8_t d_rcode;
};
class ERCodeRule : public DNSRule
{
public:
- ERCodeRule(uint8_t rcode) : d_rcode(rcode & 0xF), d_extrcode(rcode >> 4)
+ ERCodeRule(uint8_t rcode) :
+ d_rcode(rcode & 0xF), d_extrcode(rcode >> 4)
{
}
bool matches(const DNSQuestion* dq) const override
}
string toString() const override
{
- return "ercode=="+ERCode::to_s(d_rcode | (d_extrcode << 4));
+ return "ercode==" + ERCode::to_s(d_rcode | (d_extrcode << 4));
}
+
private:
- uint8_t d_rcode; // plain DNS Rcode
- uint8_t d_extrcode; // upper bits in EDNS0 record
+ uint8_t d_rcode; // plain DNS Rcode
+ uint8_t d_extrcode; // upper bits in EDNS0 record
};
class EDNSVersionRule : public DNSRule
{
public:
- EDNSVersionRule(uint8_t version) : d_version(version)
+ EDNSVersionRule(uint8_t version) :
+ d_version(version)
{
}
bool matches(const DNSQuestion* dq) const override
}
string toString() const override
{
- return "ednsversion>"+std::to_string(d_version);
+ return "ednsversion>" + std::to_string(d_version);
}
+
private:
uint8_t d_version;
};
class EDNSOptionRule : public DNSRule
{
public:
- EDNSOptionRule(uint16_t optcode) : d_optcode(optcode)
+ EDNSOptionRule(uint16_t optcode) :
+ d_optcode(optcode)
{
}
bool matches(const DNSQuestion* dq) const override
}
string toString() const override
{
- return "ednsoptcode=="+std::to_string(d_optcode);
+ return "ednsoptcode==" + std::to_string(d_optcode);
}
+
private:
uint16_t d_optcode;
};
class ProbaRule : public DNSRule
{
public:
- ProbaRule(double proba) : d_proba(proba)
+ ProbaRule(double proba) :
+ d_proba(proba)
{
}
bool matches(const DNSQuestion* dq) const override
{
- if(d_proba == 1.0)
+ if (d_proba == 1.0)
return true;
- double rnd = 1.0*dns_random_uint32() / UINT32_MAX;
+ double rnd = 1.0 * dns_random_uint32() / UINT32_MAX;
return rnd > (1.0 - d_proba);
}
string toString() const override
{
return "match with prob. " + (boost::format("%0.2f") % d_proba).str();
}
+
private:
double d_proba;
};
class TagRule : public DNSRule
{
public:
- TagRule(const std::string& tag, boost::optional<std::string> value) : d_value(std::move(value)), d_tag(tag)
+ TagRule(const std::string& tag, boost::optional<std::string> value) :
+ d_value(std::move(value)), d_tag(tag)
{
}
bool matches(const DNSQuestion* dq) const override
return false;
}
- if (!d_value) {
+ if (!d_value || d_value->empty()) {
return true;
}
class PoolAvailableRule : public DNSRule
{
public:
- PoolAvailableRule(const std::string& poolname) : d_poolname(poolname)
+ PoolAvailableRule(const std::string& poolname) :
+ d_poolname(poolname)
{
}
{
return "pool '" + d_poolname + "' is available";
}
+
private:
std::string d_poolname;
};
class PoolOutstandingRule : public DNSRule
{
public:
- PoolOutstandingRule(const std::string& poolname, const size_t limit) : d_poolname(poolname), d_limit(limit)
+ PoolOutstandingRule(const std::string& poolname, const size_t limit) :
+ d_poolname(poolname), d_limit(limit)
{
}
{
return "pool '" + d_poolname + "' outstanding > " + std::to_string(d_limit);
}
+
private:
std::string d_poolname;
size_t d_limit;
};
-class KeyValueStoreLookupRule: public DNSRule
+class KeyValueStoreLookupRule : public DNSRule
{
public:
- KeyValueStoreLookupRule(std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey): d_kvs(kvs), d_key(lookupKey)
+ KeyValueStoreLookupRule(const std::shared_ptr<KeyValueStore>& kvs, const std::shared_ptr<KeyValueLookupKey>& lookupKey) :
+ d_kvs(kvs), d_key(lookupKey)
{
}
std::shared_ptr<KeyValueLookupKey> d_key;
};
-class KeyValueStoreRangeLookupRule: public DNSRule
+class KeyValueStoreRangeLookupRule : public DNSRule
{
public:
- KeyValueStoreRangeLookupRule(std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey): d_kvs(kvs), d_key(lookupKey)
+ KeyValueStoreRangeLookupRule(const std::shared_ptr<KeyValueStore>& kvs, const std::shared_ptr<KeyValueLookupKey>& lookupKey) :
+ d_kvs(kvs), d_key(lookupKey)
{
}
class LuaRule : public DNSRule
{
public:
- typedef std::function<bool(const DNSQuestion* dq)> func_t;
- LuaRule(const func_t& func): d_func(func)
+ LuaRule(const dnsdist::selectors::LuaSelectorFunction& func) :
+ d_func(func)
{}
bool matches(const DNSQuestion* dq) const override
try {
auto lock = g_lua.lock();
return d_func(dq);
- } catch (const std::exception &e) {
+ }
+ catch (const std::exception& e) {
warnlog("LuaRule failed inside Lua: %s", e.what());
- } catch (...) {
+ }
+ catch (...) {
warnlog("LuaRule failed inside Lua: [unknown exception]");
}
return false;
{
return "Lua script";
}
+
private:
- func_t d_func;
+ dnsdist::selectors::LuaSelectorFunction d_func;
};
class LuaFFIRule : public DNSRule
{
public:
- typedef std::function<bool(dnsdist_ffi_dnsquestion_t* dq)> func_t;
- LuaFFIRule(const func_t& func): d_func(func)
+ LuaFFIRule(const dnsdist::selectors::LuaSelectorFFIFunction& func) :
+ d_func(func)
{}
bool matches(const DNSQuestion* dq) const override
try {
auto lock = g_lua.lock();
return d_func(&dqffi);
- } catch (const std::exception &e) {
+ }
+ catch (const std::exception& e) {
warnlog("LuaFFIRule failed inside Lua: %s", e.what());
- } catch (...) {
+ }
+ catch (...) {
warnlog("LuaFFIRule failed inside Lua: [unknown exception]");
}
return false;
{
return "Lua FFI script";
}
+
private:
- func_t d_func;
+ dnsdist::selectors::LuaSelectorFFIFunction d_func;
};
class LuaFFIPerThreadRule : public DNSRule
{
public:
- typedef std::function<bool(dnsdist_ffi_dnsquestion_t* dq)> func_t;
-
- LuaFFIPerThreadRule(const std::string& code): d_functionCode(code), d_functionID(s_functionsCounter++)
+ LuaFFIPerThreadRule(const std::string& code) :
+ d_functionCode(code), d_functionID(s_functionsCounter++)
{
}
/* mark the state as initialized first so if there is a syntax error
we only try to execute the code once */
state.d_initialized = true;
- state.d_func = state.d_luaContext.executeCode<func_t>(d_functionCode);
+ state.d_func = state.d_luaContext.executeCode<dnsdist::selectors::LuaSelectorFFIFunction>(d_functionCode);
}
if (!state.d_func) {
dnsdist_ffi_dnsquestion_t dqffi(const_cast<DNSQuestion*>(dq));
return state.d_func(&dqffi);
}
- catch (const std::exception &e) {
+ catch (const std::exception& e) {
warnlog("LuaFFIPerthreadRule failed inside Lua: %s", e.what());
}
catch (...) {
{
return "Lua FFI per-thread script";
}
+
private:
struct PerThreadState
{
LuaContext d_luaContext;
- func_t d_func;
+ dnsdist::selectors::LuaSelectorFFIFunction d_func;
bool d_initialized{false};
};
class ProxyProtocolValueRule : public DNSRule
{
public:
- ProxyProtocolValueRule(uint8_t type, boost::optional<std::string> value): d_value(std::move(value)), d_type(type)
+ ProxyProtocolValueRule(uint8_t type, boost::optional<std::string> value) :
+ d_value(std::move(value)), d_type(type)
{
}
}
for (const auto& entry : *dq->proxyProtocolValues) {
- if (entry.type == d_type && (!d_value || entry.content == *d_value)) {
+ if (entry.type == d_type && (!d_value || d_value->empty() || entry.content == *d_value)) {
return true;
}
}
class PayloadSizeRule : public DNSRule
{
- enum class Comparisons : uint8_t { equal, greater, greaterOrEqual, smaller, smallerOrEqual };
+ enum class Comparisons : uint8_t
+ {
+ equal,
+ greater,
+ greaterOrEqual,
+ smaller,
+ smallerOrEqual
+ };
+
public:
- PayloadSizeRule(const std::string& comparison, uint16_t size): d_size(size)
+ PayloadSizeRule(const std::string& comparison, uint16_t size) :
+ d_size(size)
{
if (comparison == "equal") {
d_comparison = Comparisons::equal;
string toString() const override
{
static const std::array<const std::string, 5> comparisonStr{
- "equal to" ,
+ "equal to",
"greater than",
"equal to or greater than",
"smaller than",
- "equal to or smaller than"
- };
+ "equal to or smaller than"};
return "payload size is " + comparisonStr.at(static_cast<size_t>(d_comparison)) + " " + std::to_string(d_size);
}
uint16_t d_size;
Comparisons d_comparison;
};
+
+namespace dnsdist::selectors
+{
+std::shared_ptr<AndRule> getAndSelector(const std::vector<std::shared_ptr<DNSRule>>& rules);
+std::shared_ptr<OrRule> getOrSelector(const std::vector<std::shared_ptr<DNSRule>>& rules);
+std::shared_ptr<NotRule> getNotSelector(const std::shared_ptr<DNSRule>& rule);
+std::shared_ptr<QNameRule> getQNameSelector(const DNSName& qname);
+std::shared_ptr<QNameSetRule> getQNameSetSelector(const DNSNameSet& qnames);
+std::shared_ptr<SuffixMatchNodeRule> getQNameSuffixSelector(const SuffixMatchNode& suffixes, bool quiet);
+std::shared_ptr<QTypeRule> getQTypeSelector(const std::string& qtypeStr, uint16_t qtypeCode);
+std::shared_ptr<QClassRule> getQClassSelector(const std::string& qclassStr, uint16_t qclassCode);
+std::shared_ptr<NetmaskGroupRule> getNetmaskGroupSelector(const NetmaskGroup& nmg, bool source, bool quiet);
+std::shared_ptr<KeyValueStoreLookupRule> getKeyValueStoreLookupSelector(const std::shared_ptr<KeyValueStore>& kvs, const std::shared_ptr<KeyValueLookupKey>& lookupKey);
+std::shared_ptr<KeyValueStoreRangeLookupRule> getKeyValueStoreRangeLookupSelector(const std::shared_ptr<KeyValueStore>& kvs, const std::shared_ptr<KeyValueLookupKey>& lookupKey);
+
+#include "dnsdist-selectors-factory-generated.hh"
+}
{
return d_visual;
}
+
+namespace dnsdist::selectors
+{
+std::shared_ptr<QClassRule> getQClassSelector(const std::string& qclassStr, uint16_t qclassCode)
+{
+ QClass qclass(qclassCode);
+ if (!qclassStr.empty()) {
+ qclass = QClass(std::string(qclassStr));
+ }
+
+ return std::make_shared<QClassRule>(qclass);
+}
+
+std::shared_ptr<QTypeRule> getQTypeSelector(const std::string& qtypeStr, uint16_t qtypeCode)
+{
+ QType qtype(qtypeCode);
+ if (!qtypeStr.empty()) {
+ qtype = std::string(qtypeStr);
+ }
+
+ return std::make_shared<QTypeRule>(qtype);
+}
+
+std::shared_ptr<SuffixMatchNodeRule> getQNameSuffixSelector(const SuffixMatchNode& suffixes, bool quiet)
+{
+ return std::make_shared<SuffixMatchNodeRule>(suffixes, quiet);
+}
+
+std::shared_ptr<QNameSetRule> getQNameSetSelector(const DNSNameSet& qnames)
+{
+ return std::make_shared<QNameSetRule>(qnames);
+}
+
+std::shared_ptr<QNameRule> getQNameSelector(const DNSName& qname)
+{
+ return std::make_shared<QNameRule>(qname);
+}
+
+std::shared_ptr<NetmaskGroupRule> getNetmaskGroupSelector(const NetmaskGroup& nmg, bool source, bool quiet)
+{
+ return std::make_shared<NetmaskGroupRule>(nmg, source, quiet);
+}
+
+std::shared_ptr<KeyValueStoreLookupRule> getKeyValueStoreLookupSelector(const std::shared_ptr<KeyValueStore>& kvs, const std::shared_ptr<KeyValueLookupKey>& lookupKey)
+{
+ return std::make_shared<KeyValueStoreLookupRule>(kvs, lookupKey);
+}
+
+std::shared_ptr<KeyValueStoreRangeLookupRule> getKeyValueStoreRangeLookupSelector(const std::shared_ptr<KeyValueStore>& kvs, const std::shared_ptr<KeyValueLookupKey>& lookupKey)
+{
+ return std::make_shared<KeyValueStoreRangeLookupRule>(kvs, lookupKey);
+}
+
+std::shared_ptr<AndRule> getAndSelector(const std::vector<std::shared_ptr<DNSRule>>& rules)
+{
+ return std::make_shared<AndRule>(rules);
+}
+
+std::shared_ptr<OrRule> getOrSelector(const std::vector<std::shared_ptr<DNSRule>>& rules)
+{
+ return std::make_shared<OrRule>(rules);
+}
+
+std::shared_ptr<NotRule> getNotSelector(const std::shared_ptr<DNSRule>& rule)
+{
+ return std::make_shared<NotRule>(rule);
+}
+
+// NOLINTNEXTLINE(bugprone-suspicious-include)
+#include "dnsdist-selectors-factory-generated.cc"
+
+}