From: Remi Gacogne Date: Tue, 24 Dec 2024 15:33:37 +0000 (+0100) Subject: dnsdist: Actually move selectors and actions to proper factories X-Git-Tag: dnsdist-2.0.0-alpha1~160^2~30 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=d5ad413a3069181834c4188e0787d8d14cb7982c;p=thirdparty%2Fpdns.git dnsdist: Actually move selectors and actions to proper factories --- diff --git a/pdns/dnsdistdist/Makefile.am b/pdns/dnsdistdist/Makefile.am index f7ca43fab0..3051932f08 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -172,6 +172,7 @@ dnsdist_SOURCES = \ dns_random.hh \ dnscrypt.cc dnscrypt.hh \ dnsdist-actions-factory-generated.hh \ + dnsdist-actions-factory.cc dnsdist-actions-factory.hh \ dnsdist-actions.cc dnsdist-actions.hh \ dnsdist-async.cc dnsdist-async.hh \ dnsdist-backend.cc dnsdist-backend.hh \ @@ -454,6 +455,7 @@ endif if HAVE_RE2 dnsdist_LDADD += $(RE2_LIBS) +testrunner_LDADD += $(RE2_LIBS) endif if HAVE_LIBSSL diff --git a/pdns/dnsdistdist/dnsdist-actions-factory.cc b/pdns/dnsdistdist/dnsdist-actions-factory.cc new file mode 100644 index 0000000000..052d08f7c3 --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-actions-factory.cc @@ -0,0 +1,2444 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#include + +#include "dnsdist-actions-factory.hh" + +#include "config.h" +#include "dnsdist.hh" +#include "dnsdist-async.hh" +#include "dnsdist-dnsparser.hh" +#include "dnsdist-ecs.hh" +#include "dnsdist-edns.hh" +#include "dnsdist-lua.hh" +#include "dnsdist-lua-ffi.hh" +#include "dnsdist-mac-address.hh" +#include "dnsdist-protobuf.hh" +#include "dnsdist-proxy-protocol.hh" +#include "dnsdist-kvs.hh" +#include "dnsdist-rule-chains.hh" +#include "dnsdist-self-answers.hh" +#include "dnsdist-snmp.hh" + +#include "dnstap.hh" +#include "dnswriter.hh" +#include "ednsoptions.hh" +#include "fstrm_logger.hh" +#include "ipcipher.hh" +#include "remote_logger.hh" +#include "svc-records.hh" +#include "threadname.hh" + +namespace dnsdist::actions +{ +class DropAction : public DNSAction +{ +public: + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + return Action::Drop; + } + [[nodiscard]] std::string toString() const override + { + return "drop"; + } +}; + +class AllowAction : public DNSAction +{ +public: + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + return Action::Allow; + } + [[nodiscard]] std::string toString() const override + { + return "allow"; + } +}; + +class NoneAction : public DNSAction +{ +public: + // this action does not stop the processing + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + return Action::None; + } + [[nodiscard]] std::string toString() const override + { + return "no op"; + } +}; + +class QPSAction : public DNSAction +{ +public: + QPSAction(int limit) : + d_qps(QPSLimiter(limit, limit)) + { + } + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + if (d_qps.lock()->check()) { + return Action::None; + } + return Action::Drop; + } + [[nodiscard]] std::string toString() const override + { + return "qps limit to " + std::to_string(d_qps.lock()->getRate()); + } + +private: + mutable LockGuarded d_qps; +}; + +class DelayAction : public DNSAction +{ +public: + DelayAction(int msec) : + d_msec(msec) + { + } + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + *ruleresult = std::to_string(d_msec); + return Action::Delay; + } + [[nodiscard]] std::string toString() const override + { + return "delay by " + std::to_string(d_msec) + " ms"; + } + +private: + int d_msec; +}; + +class TeeAction : public DNSAction +{ +public: + // this action does not stop the processing + TeeAction(const ComboAddress& rca, const std::optional& lca, bool addECS = false, bool addProxyProtocol = false); + TeeAction(TeeAction& other) = delete; + TeeAction(TeeAction&& other) = delete; + TeeAction& operator=(TeeAction& other) = delete; + TeeAction& operator=(TeeAction&& other) = delete; + ~TeeAction() override; + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override; + [[nodiscard]] std::string toString() const override; + std::map getStats() const override; + +private: + void worker(); + + ComboAddress d_remote; + std::thread d_worker; + Socket d_socket; + mutable std::atomic d_senderrors{0}; + unsigned long d_recverrors{0}; + mutable std::atomic d_queries{0}; + stat_t d_responses{0}; + stat_t d_nxdomains{0}; + stat_t d_servfails{0}; + stat_t d_refuseds{0}; + stat_t d_formerrs{0}; + stat_t d_notimps{0}; + stat_t d_noerrors{0}; + mutable stat_t d_tcpdrops{0}; + stat_t d_otherrcode{0}; + std::atomic d_pleaseQuit{false}; + bool d_addECS{false}; + bool d_addProxyProtocol{false}; +}; + +TeeAction::TeeAction(const ComboAddress& rca, const std::optional& lca, bool addECS, bool addProxyProtocol) : + d_remote(rca), d_socket(d_remote.sin4.sin_family, SOCK_DGRAM, 0), d_addECS(addECS), d_addProxyProtocol(addProxyProtocol) +{ + if (lca) { + d_socket.bind(*lca, false); + } + d_socket.connect(d_remote); + d_socket.setNonBlocking(); + d_worker = std::thread([this]() { + worker(); + }); +} + +TeeAction::~TeeAction() +{ + d_pleaseQuit = true; + close(d_socket.releaseHandle()); + d_worker.join(); +} + +DNSAction::Action TeeAction::operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const +{ + if (dnsquestion->overTCP()) { + d_tcpdrops++; + return DNSAction::Action::None; + } + + d_queries++; + + PacketBuffer query; + if (d_addECS) { + query = dnsquestion->getData(); + bool ednsAdded = false; + bool ecsAdded = false; + + std::string newECSOption; + generateECSOption(dnsquestion->ecs ? dnsquestion->ecs->getNetwork() : dnsquestion->ids.origRemote, newECSOption, dnsquestion->ecs ? dnsquestion->ecs->getBits() : dnsquestion->ecsPrefixLength); + + if (!handleEDNSClientSubnet(query, dnsquestion->getMaximumSize(), dnsquestion->ids.qname.wirelength(), ednsAdded, ecsAdded, dnsquestion->ecsOverride, newECSOption)) { + return DNSAction::Action::None; + } + } + + if (d_addProxyProtocol) { + auto proxyPayload = getProxyProtocolPayload(*dnsquestion); + if (query.empty()) { + query = dnsquestion->getData(); + } + if (!addProxyProtocol(query, proxyPayload)) { + return DNSAction::Action::None; + } + } + + { + const PacketBuffer& payload = query.empty() ? dnsquestion->getData() : query; + auto res = send(d_socket.getHandle(), payload.data(), payload.size(), 0); + + if (res <= 0) { + d_senderrors++; + } + } + + return DNSAction::Action::None; +} + +std::string TeeAction::toString() const +{ + return "tee to " + d_remote.toStringWithPort(); +} + +std::map TeeAction::getStats() const +{ + return {{"queries", d_queries}, + {"responses", d_responses}, + {"recv-errors", d_recverrors}, + {"send-errors", d_senderrors}, + {"noerrors", d_noerrors}, + {"nxdomains", d_nxdomains}, + {"refuseds", d_refuseds}, + {"servfails", d_servfails}, + {"other-rcode", d_otherrcode}, + {"tcp-drops", d_tcpdrops}}; +} + +void TeeAction::worker() +{ + setThreadName("dnsdist/TeeWork"); + std::array packet{}; + ssize_t res = 0; + const dnsheader_aligned dnsheader(packet.data()); + for (;;) { + res = waitForData(d_socket.getHandle(), 0, 250000); + if (d_pleaseQuit) { + break; + } + + if (res < 0) { + usleep(250000); + continue; + } + if (res == 0) { + continue; + } + res = recv(d_socket.getHandle(), packet.data(), packet.size(), 0); + if (static_cast(res) <= sizeof(struct dnsheader)) { + d_recverrors++; + } + else { + d_responses++; + } + + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well + if (dnsheader->rcode == RCode::NoError) { + d_noerrors++; + } + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well + else if (dnsheader->rcode == RCode::ServFail) { + d_servfails++; + } + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well + else if (dnsheader->rcode == RCode::NXDomain) { + d_nxdomains++; + } + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well + else if (dnsheader->rcode == RCode::Refused) { + d_refuseds++; + } + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well + else if (dnsheader->rcode == RCode::FormErr) { + d_formerrs++; + } + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well + else if (dnsheader->rcode == RCode::NotImp) { + d_notimps++; + } + } +} + +class PoolAction : public DNSAction +{ +public: + PoolAction(std::string pool, bool stopProcessing) : + d_pool(std::move(pool)), d_stopProcessing(stopProcessing) {} + + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + if (d_stopProcessing) { + /* we need to do it that way to keep compatiblity with custom Lua actions returning DNSAction.Pool, 'poolname' */ + *ruleresult = d_pool; + return Action::Pool; + } + dnsquestion->ids.poolName = d_pool; + return Action::None; + } + + [[nodiscard]] std::string toString() const override + { + return "to pool " + d_pool; + } + +private: + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const std::string d_pool; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const bool d_stopProcessing; +}; + +class QPSPoolAction : public DNSAction +{ +public: + QPSPoolAction(unsigned int limit, std::string pool, bool stopProcessing) : + d_qps(QPSLimiter(limit, limit)), d_pool(std::move(pool)), d_stopProcessing(stopProcessing) {} + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + if (d_qps.lock()->check()) { + if (d_stopProcessing) { + /* we need to do it that way to keep compatiblity with custom Lua actions returning DNSAction.Pool, 'poolname' */ + *ruleresult = d_pool; + return Action::Pool; + } + dnsquestion->ids.poolName = d_pool; + } + return Action::None; + } + [[nodiscard]] std::string toString() const override + { + return "max " + std::to_string(d_qps.lock()->getRate()) + " to pool " + d_pool; + } + +private: + mutable LockGuarded d_qps; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const std::string d_pool; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const bool d_stopProcessing; +}; + +class RCodeAction : public DNSAction +{ +public: + RCodeAction(uint8_t rcode, const dnsdist::ResponseConfig& responseConfig) : + d_responseConfig(responseConfig), d_rcode(rcode) {} + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) { + header.rcode = d_rcode; + header.qr = true; // for good measure + setResponseHeadersFromConfig(header, d_responseConfig); + return true; + }); + return Action::HeaderModify; + } + [[nodiscard]] std::string toString() const override + { + return "set rcode " + std::to_string(d_rcode); + } + +private: + dnsdist::ResponseConfig d_responseConfig; + uint8_t d_rcode; +}; + +class ERCodeAction : public DNSAction +{ +public: + ERCodeAction(uint8_t rcode, dnsdist::ResponseConfig responseConfig) : + d_responseConfig(std::move(responseConfig)), d_rcode(rcode) + { + } + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) { + header.rcode = (d_rcode & 0xF); + header.qr = true; // for good measure + setResponseHeadersFromConfig(header, d_responseConfig); + return true; + }); + dnsquestion->ednsRCode = ((d_rcode & 0xFFF0) >> 4); + return Action::HeaderModify; + } + [[nodiscard]] std::string toString() const override + { + return "set ercode " + ERCode::to_s(d_rcode); + } + +private: + dnsdist::ResponseConfig d_responseConfig; + uint8_t d_rcode; +}; + +class SpoofSVCAction : public DNSAction +{ +public: + SpoofSVCAction(const std::vector& parameters, const dnsdist::ResponseConfig& responseConfig) : + d_responseConfig(responseConfig) + { + d_payloads.reserve(parameters.size()); + + for (const auto& param : parameters) { + std::vector payload; + if (!generateSVCPayload(payload, param)) { + throw std::runtime_error("Unable to generate a valid SVC record from the supplied parameters"); + } + + d_payloads.push_back(std::move(payload)); + + for (const auto& hint : param.ipv4hints) { + d_additionals4.insert({param.target, ComboAddress(hint)}); + } + + for (const auto& hint : param.ipv6hints) { + d_additionals6.insert({param.target, ComboAddress(hint)}); + } + } + } + + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + if (!dnsdist::svc::generateSVCResponse(*dnsquestion, d_payloads, d_additionals4, d_additionals6, d_responseConfig)) { + return Action::None; + } + + return Action::HeaderModify; + } + + [[nodiscard]] std::string toString() const override + { + return "spoof SVC record "; + } + +private: + dnsdist::ResponseConfig d_responseConfig; + std::vector> d_payloads{}; + std::set> d_additionals4{}; + std::set> d_additionals6{}; +}; + +class TCAction : public DNSAction +{ +public: + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + return Action::Truncate; + } + [[nodiscard]] std::string toString() const override + { + return "tc=1 answer"; + } +}; + +class TCResponseAction : public DNSResponseAction +{ +public: + DNSResponseAction::Action operator()(DNSResponse* dnsResponse, std::string* ruleresult) const override + { + return Action::Truncate; + } + [[nodiscard]] std::string toString() const override + { + return "tc=1 answer"; + } +}; + +class LuaAction : public DNSAction +{ +public: + LuaAction(LuaActionFunction func) : + d_func(std::move(func)) + {} + + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + try { + DNSAction::Action result{}; + { + auto lock = g_lua.lock(); + auto ret = d_func(dnsquestion); + if (ruleresult != nullptr) { + if (boost::optional rule = std::get<1>(ret)) { + *ruleresult = *rule; + } + else { + // default to empty string + ruleresult->clear(); + } + } + result = static_cast(std::get<0>(ret)); + } + dnsdist::handleQueuedAsynchronousEvents(); + return result; + } + catch (const std::exception& e) { + warnlog("LuaAction failed inside Lua, returning ServFail: %s", e.what()); + } + catch (...) { + warnlog("LuaAction failed inside Lua, returning ServFail: [unknown exception]"); + } + return DNSAction::Action::ServFail; + } + + [[nodiscard]] std::string toString() const override + { + return "Lua script"; + } + +private: + LuaActionFunction d_func; +}; + +class LuaResponseAction : public DNSResponseAction +{ +public: + LuaResponseAction(LuaResponseActionFunction func) : + d_func(std::move(func)) + {} + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override + { + try { + DNSResponseAction::Action result{}; + { + auto lock = g_lua.lock(); + auto ret = d_func(response); + if (ruleresult != nullptr) { + if (boost::optional rule = std::get<1>(ret)) { + *ruleresult = *rule; + } + else { + // default to empty string + ruleresult->clear(); + } + } + result = static_cast(std::get<0>(ret)); + } + dnsdist::handleQueuedAsynchronousEvents(); + return result; + } + catch (const std::exception& e) { + warnlog("LuaResponseAction failed inside Lua, returning ServFail: %s", e.what()); + } + catch (...) { + warnlog("LuaResponseAction failed inside Lua, returning ServFail: [unknown exception]"); + } + return DNSResponseAction::Action::ServFail; + } + + [[nodiscard]] std::string toString() const override + { + return "Lua response script"; + } + +private: + LuaResponseActionFunction d_func; +}; + +class LuaFFIAction : public DNSAction +{ +public: + LuaFFIAction(LuaActionFFIFunction func) : + d_func(std::move(func)) + { + } + + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + dnsdist_ffi_dnsquestion_t dqffi(dnsquestion); + try { + DNSAction::Action result{}; + { + auto lock = g_lua.lock(); + auto ret = d_func(&dqffi); + if (ruleresult != nullptr) { + if (dqffi.result) { + *ruleresult = *dqffi.result; + } + else { + // default to empty string + ruleresult->clear(); + } + } + result = static_cast(ret); + } + dnsdist::handleQueuedAsynchronousEvents(); + return result; + } + catch (const std::exception& e) { + warnlog("LuaFFIAction failed inside Lua, returning ServFail: %s", e.what()); + } + catch (...) { + warnlog("LuaFFIAction failed inside Lua, returning ServFail: [unknown exception]"); + } + return DNSAction::Action::ServFail; + } + + [[nodiscard]] std::string toString() const override + { + return "Lua FFI script"; + } + +private: + LuaActionFFIFunction d_func; +}; + +class LuaFFIPerThreadAction : public DNSAction +{ +public: + LuaFFIPerThreadAction(std::string code) : + d_functionCode(std::move(code)), d_functionID(s_functionsCounter++) + { + } + + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + try { + auto& state = t_perThreadStates[d_functionID]; + if (!state.d_initialized) { + setupLuaFFIPerThreadContext(state.d_luaContext); + /* mark the state as initialized first so if there is a syntax error + we only try to execute the code once */ + state.d_initialized = true; + state.d_func = state.d_luaContext.executeCode(d_functionCode); + } + + if (!state.d_func) { + /* the function was not properly initialized */ + return DNSAction::Action::None; + } + + dnsdist_ffi_dnsquestion_t dqffi(dnsquestion); + auto ret = state.d_func(&dqffi); + if (ruleresult != nullptr) { + if (dqffi.result) { + *ruleresult = *dqffi.result; + } + else { + // default to empty string + ruleresult->clear(); + } + } + dnsdist::handleQueuedAsynchronousEvents(); + return static_cast(ret); + } + catch (const std::exception& e) { + warnlog("LuaFFIPerThreadAction failed inside Lua, returning ServFail: %s", e.what()); + } + catch (...) { + warnlog("LuaFFIPerthreadAction failed inside Lua, returning ServFail: [unknown exception]"); + } + return DNSAction::Action::ServFail; + } + + [[nodiscard]] std::string toString() const override + { + return "Lua FFI per-thread script"; + } + +private: + struct PerThreadState + { + LuaContext d_luaContext; + LuaActionFFIFunction d_func; + bool d_initialized{false}; + }; + static std::atomic s_functionsCounter; + static thread_local std::map t_perThreadStates; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const std::string d_functionCode; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const uint64_t d_functionID; +}; + +std::atomic LuaFFIPerThreadAction::s_functionsCounter = 0; +thread_local std::map LuaFFIPerThreadAction::t_perThreadStates; + +class LuaFFIResponseAction : public DNSResponseAction +{ +public: + LuaFFIResponseAction(LuaResponseActionFFIFunction func) : + d_func(std::move(func)) + { + } + + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override + { + dnsdist_ffi_dnsresponse_t ffiResponse(response); + try { + DNSResponseAction::Action result{}; + { + auto lock = g_lua.lock(); + auto ret = d_func(&ffiResponse); + if (ruleresult != nullptr) { + if (ffiResponse.result) { + *ruleresult = *ffiResponse.result; + } + else { + // default to empty string + ruleresult->clear(); + } + } + result = static_cast(ret); + } + dnsdist::handleQueuedAsynchronousEvents(); + return result; + } + catch (const std::exception& e) { + warnlog("LuaFFIResponseAction failed inside Lua, returning ServFail: %s", e.what()); + } + catch (...) { + warnlog("LuaFFIResponseAction failed inside Lua, returning ServFail: [unknown exception]"); + } + return DNSResponseAction::Action::ServFail; + } + + [[nodiscard]] std::string toString() const override + { + return "Lua FFI script"; + } + +private: + LuaResponseActionFFIFunction d_func; +}; + +class LuaFFIPerThreadResponseAction : public DNSResponseAction +{ +public: + LuaFFIPerThreadResponseAction(std::string code) : + d_functionCode(std::move(code)), d_functionID(s_functionsCounter++) + { + } + + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override + { + try { + auto& state = t_perThreadStates[d_functionID]; + if (!state.d_initialized) { + setupLuaFFIPerThreadContext(state.d_luaContext); + /* mark the state as initialized first so if there is a syntax error + we only try to execute the code once */ + state.d_initialized = true; + state.d_func = state.d_luaContext.executeCode(d_functionCode); + } + + if (!state.d_func) { + /* the function was not properly initialized */ + return DNSResponseAction::Action::None; + } + + dnsdist_ffi_dnsresponse_t ffiResponse(response); + auto ret = state.d_func(&ffiResponse); + if (ruleresult != nullptr) { + if (ffiResponse.result) { + *ruleresult = *ffiResponse.result; + } + else { + // default to empty string + ruleresult->clear(); + } + } + dnsdist::handleQueuedAsynchronousEvents(); + return static_cast(ret); + } + catch (const std::exception& e) { + warnlog("LuaFFIPerThreadResponseAction failed inside Lua, returning ServFail: %s", e.what()); + } + catch (...) { + warnlog("LuaFFIPerthreadResponseAction failed inside Lua, returning ServFail: [unknown exception]"); + } + return DNSResponseAction::Action::ServFail; + } + + [[nodiscard]] std::string toString() const override + { + return "Lua FFI per-thread script"; + } + +private: + struct PerThreadState + { + LuaContext d_luaContext; + LuaResponseActionFFIFunction d_func; + bool d_initialized{false}; + }; + + static std::atomic s_functionsCounter; + static thread_local std::map t_perThreadStates; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const std::string d_functionCode; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const uint64_t d_functionID; +}; + +std::atomic LuaFFIPerThreadResponseAction::s_functionsCounter = 0; +thread_local std::map LuaFFIPerThreadResponseAction::t_perThreadStates; + +class SpoofAction : public DNSAction +{ +public: + SpoofAction(const vector& addrs, const dnsdist::ResponseConfig& responseConfig) : + d_responseConfig(responseConfig), d_addrs(addrs) + { + for (const auto& addr : d_addrs) { + if (addr.isIPv4()) { + d_types.insert(QType::A); + } + else if (addr.isIPv6()) { + d_types.insert(QType::AAAA); + } + } + + if (!d_addrs.empty()) { + d_types.insert(QType::ANY); + } + } + + SpoofAction(const DNSName& cname, const dnsdist::ResponseConfig& responseConfig) : + d_responseConfig(responseConfig), d_cname(cname) + { + } + + SpoofAction(const PacketBuffer& rawresponse) : + d_raw(rawresponse) + { + } + + SpoofAction(const vector& raws, std::optional typeForAny, const dnsdist::ResponseConfig& responseConfig) : + d_responseConfig(responseConfig), d_rawResponses(raws), d_rawTypeForAny(typeForAny) + { + } + + DNSAction::Action operator()(DNSQuestion* dnsquestion, string* ruleresult) const override; + + string toString() const override + { + string ret = "spoof in "; + if (!d_cname.empty()) { + ret += d_cname.toString() + " "; + } + if (d_rawResponses.size() > 0) { + ret += "raw bytes "; + } + else { + for (const auto& a : d_addrs) + ret += a.toString() + " "; + } + return ret; + } + +private: + dnsdist::ResponseConfig d_responseConfig; + std::vector d_addrs; + std::unordered_set d_types; + std::vector d_rawResponses; + PacketBuffer d_raw; + DNSName d_cname; + std::optional d_rawTypeForAny{}; +}; + +DNSAction::Action SpoofAction::operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const +{ + uint16_t qtype = dnsquestion->ids.qtype; + // do we even have a response? + if (d_cname.empty() && d_rawResponses.empty() && + // make sure pre-forged response is greater than sizeof(dnsheader) + (d_raw.size() < sizeof(dnsheader)) && d_types.count(qtype) == 0) { + return Action::None; + } + + if (d_raw.size() >= sizeof(dnsheader)) { + dnsdist::self_answers::generateAnswerFromRawPacket(*dnsquestion, d_raw); + return Action::HeaderModify; + } + + if (!d_cname.empty()) { + if (dnsdist::self_answers::generateAnswerFromCNAME(*dnsquestion, d_cname, d_responseConfig)) { + return Action::HeaderModify; + } + } + else if (!d_rawResponses.empty()) { + if (dnsdist::self_answers::generateAnswerFromRDataEntries(*dnsquestion, d_rawResponses, d_rawTypeForAny, d_responseConfig)) { + return Action::HeaderModify; + } + } + else { + if (dnsdist::self_answers::generateAnswerFromIPAddresses(*dnsquestion, d_addrs, d_responseConfig)) { + return Action::HeaderModify; + } + } + + return Action::None; +} + +class SetMacAddrAction : public DNSAction +{ +public: + // this action does not stop the processing + SetMacAddrAction(uint16_t code) : + d_code(code) + { + } + + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + dnsdist::MacAddress mac{}; + int res = dnsdist::MacAddressesCache::get(dnsquestion->ids.origRemote, mac.data(), mac.size()); + if (res != 0) { + return Action::None; + } + + std::string optRData; + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + generateEDNSOption(d_code, reinterpret_cast(mac.data()), optRData); + + if (dnsquestion->getHeader()->arcount > 0) { + bool ednsAdded = false; + bool optionAdded = false; + PacketBuffer newContent; + newContent.reserve(dnsquestion->getData().size()); + + if (!slowRewriteEDNSOptionInQueryWithRecords(dnsquestion->getData(), newContent, ednsAdded, d_code, optionAdded, true, optRData)) { + return Action::None; + } + + if (newContent.size() > dnsquestion->getMaximumSize()) { + return Action::None; + } + + dnsquestion->getMutableData() = std::move(newContent); + if (!dnsquestion->ids.ednsAdded && ednsAdded) { + dnsquestion->ids.ednsAdded = true; + } + + return Action::None; + } + + auto& data = dnsquestion->getMutableData(); + if (generateOptRR(optRData, data, dnsquestion->getMaximumSize(), dnsdist::configuration::s_EdnsUDPPayloadSize, 0, false)) { + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [](dnsheader& header) { + header.arcount = htons(1); + return true; + }); + // make sure that any EDNS sent by the backend is removed before forwarding the response to the client + dnsquestion->ids.ednsAdded = true; + } + + return Action::None; + } + [[nodiscard]] std::string toString() const override + { + return "add EDNS MAC (code=" + std::to_string(d_code) + ")"; + } + +private: + uint16_t d_code{3}; +}; + +class SetEDNSOptionAction : public DNSAction +{ +public: + // this action does not stop the processing + SetEDNSOptionAction(uint16_t code, std::string data) : + d_code(code), d_data(std::move(data)) + { + } + + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + setEDNSOption(*dnsquestion, d_code, d_data); + return Action::None; + } + + [[nodiscard]] std::string toString() const override + { + return "add EDNS Option (code=" + std::to_string(d_code) + ")"; + } + +private: + uint16_t d_code; + std::string d_data; +}; + +class SetNoRecurseAction : public DNSAction +{ +public: + // this action does not stop the processing + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [](dnsheader& header) { + header.rd = false; + return true; + }); + return Action::None; + } + [[nodiscard]] std::string toString() const override + { + return "set rd=0"; + } +}; + +class LogAction : public DNSAction, public boost::noncopyable +{ +public: + // this action does not stop the processing + LogAction() = default; + + LogAction(const std::string& str, bool binary = true, bool append = false, bool buffered = true, bool verboseOnly = true, bool includeTimestamp = false) : + d_fname(str), d_binary(binary), d_verboseOnly(verboseOnly), d_includeTimestamp(includeTimestamp), d_append(append), d_buffered(buffered) + { + if (str.empty()) { + return; + } + + if (!reopenLogFile()) { + throw std::runtime_error("Unable to open file '" + str + "' for logging: " + stringerror()); + } + } + + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + auto filepointer = std::atomic_load_explicit(&d_fp, std::memory_order_acquire); + if (!filepointer) { + if (!d_verboseOnly || dnsdist::configuration::getCurrentRuntimeConfiguration().d_verbose) { + if (d_includeTimestamp) { + infolog("[%u.%u] Packet from %s for %s %s with id %d", static_cast(dnsquestion->getQueryRealTime().tv_sec), static_cast(dnsquestion->getQueryRealTime().tv_nsec), dnsquestion->ids.origRemote.toStringWithPort(), dnsquestion->ids.qname.toString(), QType(dnsquestion->ids.qtype).toString(), dnsquestion->getHeader()->id); + } + else { + infolog("Packet from %s for %s %s with id %d", dnsquestion->ids.origRemote.toStringWithPort(), dnsquestion->ids.qname.toString(), QType(dnsquestion->ids.qtype).toString(), dnsquestion->getHeader()->id); + } + } + } + else { + if (d_binary) { + const auto& out = dnsquestion->ids.qname.getStorage(); + if (d_includeTimestamp) { + auto tv_sec = static_cast(dnsquestion->getQueryRealTime().tv_sec); + auto tv_nsec = static_cast(dnsquestion->getQueryRealTime().tv_nsec); + fwrite(&tv_sec, sizeof(tv_sec), 1, filepointer.get()); + fwrite(&tv_nsec, sizeof(tv_nsec), 1, filepointer.get()); + } + uint16_t queryId = dnsquestion->getHeader()->id; + fwrite(&queryId, sizeof(queryId), 1, filepointer.get()); + fwrite(out.c_str(), 1, out.size(), filepointer.get()); + fwrite(&dnsquestion->ids.qtype, sizeof(dnsquestion->ids.qtype), 1, filepointer.get()); + fwrite(&dnsquestion->ids.origRemote.sin4.sin_family, sizeof(dnsquestion->ids.origRemote.sin4.sin_family), 1, filepointer.get()); + if (dnsquestion->ids.origRemote.sin4.sin_family == AF_INET) { + fwrite(&dnsquestion->ids.origRemote.sin4.sin_addr.s_addr, sizeof(dnsquestion->ids.origRemote.sin4.sin_addr.s_addr), 1, filepointer.get()); + } + else if (dnsquestion->ids.origRemote.sin4.sin_family == AF_INET6) { + fwrite(&dnsquestion->ids.origRemote.sin6.sin6_addr.s6_addr, sizeof(dnsquestion->ids.origRemote.sin6.sin6_addr.s6_addr), 1, filepointer.get()); + } + fwrite(&dnsquestion->ids.origRemote.sin4.sin_port, sizeof(dnsquestion->ids.origRemote.sin4.sin_port), 1, filepointer.get()); + } + else { + if (d_includeTimestamp) { + fprintf(filepointer.get(), "[%llu.%lu] Packet from %s for %s %s with id %u\n", static_cast(dnsquestion->getQueryRealTime().tv_sec), static_cast(dnsquestion->getQueryRealTime().tv_nsec), dnsquestion->ids.origRemote.toStringWithPort().c_str(), dnsquestion->ids.qname.toString().c_str(), QType(dnsquestion->ids.qtype).toString().c_str(), dnsquestion->getHeader()->id); + } + else { + fprintf(filepointer.get(), "Packet from %s for %s %s with id %u\n", dnsquestion->ids.origRemote.toStringWithPort().c_str(), dnsquestion->ids.qname.toString().c_str(), QType(dnsquestion->ids.qtype).toString().c_str(), dnsquestion->getHeader()->id); + } + } + } + return Action::None; + } + + [[nodiscard]] std::string toString() const override + { + if (!d_fname.empty()) { + return "log to " + d_fname; + } + return "log"; + } + + void reload() override + { + if (!reopenLogFile()) { + warnlog("Unable to open file '%s' for logging: %s", d_fname, stringerror()); + } + } + +private: + bool reopenLogFile() + { + // we are using a naked pointer here because we don't want fclose to be called + // with a nullptr, which would happen if we constructor a shared_ptr with fclose + // as a custom deleter and nullptr as a FILE* + // NOLINTNEXTLINE(cppcoreguidelines-owning-memory) + auto* nfp = fopen(d_fname.c_str(), d_append ? "a+" : "w"); + if (nfp == nullptr) { + /* don't fall on our sword when reopening */ + return false; + } + + auto filepointer = std::shared_ptr(nfp, fclose); + nfp = nullptr; + + if (!d_buffered) { + setbuf(filepointer.get(), nullptr); + } + + std::atomic_store_explicit(&d_fp, std::move(filepointer), std::memory_order_release); + return true; + } + + std::string d_fname; + std::shared_ptr d_fp{nullptr}; + bool d_binary{true}; + bool d_verboseOnly{true}; + bool d_includeTimestamp{false}; + bool d_append{false}; + bool d_buffered{true}; +}; + +class LogResponseAction : public DNSResponseAction, public boost::noncopyable +{ +public: + LogResponseAction() = default; + + LogResponseAction(const std::string& str, bool append = false, bool buffered = true, bool verboseOnly = true, bool includeTimestamp = false) : + d_fname(str), d_verboseOnly(verboseOnly), d_includeTimestamp(includeTimestamp), d_append(append), d_buffered(buffered) + { + if (str.empty()) { + return; + } + + if (!reopenLogFile()) { + throw std::runtime_error("Unable to open file '" + str + "' for logging: " + stringerror()); + } + } + + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override + { + auto filepointer = std::atomic_load_explicit(&d_fp, std::memory_order_acquire); + if (!filepointer) { + if (!d_verboseOnly || dnsdist::configuration::getCurrentRuntimeConfiguration().d_verbose) { + if (d_includeTimestamp) { + infolog("[%u.%u] Answer to %s for %s %s (%s) with id %u", static_cast(response->getQueryRealTime().tv_sec), static_cast(response->getQueryRealTime().tv_nsec), response->ids.origRemote.toStringWithPort(), response->ids.qname.toString(), QType(response->ids.qtype).toString(), RCode::to_s(response->getHeader()->rcode), response->getHeader()->id); + } + else { + infolog("Answer to %s for %s %s (%s) with id %u", response->ids.origRemote.toStringWithPort(), response->ids.qname.toString(), QType(response->ids.qtype).toString(), RCode::to_s(response->getHeader()->rcode), response->getHeader()->id); + } + } + } + else { + if (d_includeTimestamp) { + fprintf(filepointer.get(), "[%llu.%lu] Answer to %s for %s %s (%s) with id %u\n", static_cast(response->getQueryRealTime().tv_sec), static_cast(response->getQueryRealTime().tv_nsec), response->ids.origRemote.toStringWithPort().c_str(), response->ids.qname.toString().c_str(), QType(response->ids.qtype).toString().c_str(), RCode::to_s(response->getHeader()->rcode).c_str(), response->getHeader()->id); + } + else { + fprintf(filepointer.get(), "Answer to %s for %s %s (%s) with id %u\n", response->ids.origRemote.toStringWithPort().c_str(), response->ids.qname.toString().c_str(), QType(response->ids.qtype).toString().c_str(), RCode::to_s(response->getHeader()->rcode).c_str(), response->getHeader()->id); + } + } + return Action::None; + } + + [[nodiscard]] std::string toString() const override + { + if (!d_fname.empty()) { + return "log to " + d_fname; + } + return "log"; + } + + void reload() override + { + if (!reopenLogFile()) { + warnlog("Unable to open file '%s' for logging: %s", d_fname, stringerror()); + } + } + +private: + bool reopenLogFile() + { + // we are using a naked pointer here because we don't want fclose to be called + // with a nullptr, which would happen if we constructor a shared_ptr with fclose + // as a custom deleter and nullptr as a FILE* + // NOLINTNEXTLINE(cppcoreguidelines-owning-memory) + auto* nfp = fopen(d_fname.c_str(), d_append ? "a+" : "w"); + if (nfp == nullptr) { + /* don't fall on our sword when reopening */ + return false; + } + + auto filepointer = std::shared_ptr(nfp, fclose); + nfp = nullptr; + + if (!d_buffered) { + setbuf(filepointer.get(), nullptr); + } + + std::atomic_store_explicit(&d_fp, std::move(filepointer), std::memory_order_release); + return true; + } + + std::string d_fname; + std::shared_ptr d_fp{nullptr}; + bool d_verboseOnly{true}; + bool d_includeTimestamp{false}; + bool d_append{false}; + bool d_buffered{true}; +}; + +class SetDisableValidationAction : public DNSAction +{ +public: + // this action does not stop the processing + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [](dnsheader& header) { + header.cd = true; + return true; + }); + return Action::None; + } + [[nodiscard]] std::string toString() const override + { + return "set cd=1"; + } +}; + +class SetSkipCacheAction : public DNSAction +{ +public: + // this action does not stop the processing + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + dnsquestion->ids.skipCache = true; + return Action::None; + } + [[nodiscard]] std::string toString() const override + { + return "skip cache"; + } +}; + +class SetSkipCacheResponseAction : public DNSResponseAction +{ +public: + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override + { + response->ids.skipCache = true; + return Action::None; + } + [[nodiscard]] std::string toString() const override + { + return "skip cache"; + } +}; + +class SetTempFailureCacheTTLAction : public DNSAction +{ +public: + // this action does not stop the processing + SetTempFailureCacheTTLAction(uint32_t ttl) : + d_ttl(ttl) + { + } + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + dnsquestion->ids.tempFailureTTL = d_ttl; + return Action::None; + } + [[nodiscard]] std::string toString() const override + { + return "set tempfailure cache ttl to " + std::to_string(d_ttl); + } + +private: + uint32_t d_ttl; +}; + +class SetECSPrefixLengthAction : public DNSAction +{ +public: + // this action does not stop the processing + SetECSPrefixLengthAction(uint16_t v4Length, uint16_t v6Length) : + d_v4PrefixLength(v4Length), d_v6PrefixLength(v6Length) + { + } + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + dnsquestion->ecsPrefixLength = dnsquestion->ids.origRemote.sin4.sin_family == AF_INET ? d_v4PrefixLength : d_v6PrefixLength; + return Action::None; + } + [[nodiscard]] std::string toString() const override + { + return "set ECS prefix length to " + std::to_string(d_v4PrefixLength) + "/" + std::to_string(d_v6PrefixLength); + } + +private: + uint16_t d_v4PrefixLength; + uint16_t d_v6PrefixLength; +}; + +class SetECSOverrideAction : public DNSAction +{ +public: + // this action does not stop the processing + SetECSOverrideAction(bool ecsOverride) : + d_ecsOverride(ecsOverride) + { + } + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + dnsquestion->ecsOverride = d_ecsOverride; + return Action::None; + } + [[nodiscard]] std::string toString() const override + { + return "set ECS override to " + std::to_string(static_cast(d_ecsOverride)); + } + +private: + bool d_ecsOverride; +}; + +class SetDisableECSAction : public DNSAction +{ +public: + // this action does not stop the processing + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + dnsquestion->useECS = false; + return Action::None; + } + [[nodiscard]] std::string toString() const override + { + return "disable ECS"; + } +}; + +class SetECSAction : public DNSAction +{ +public: + // this action does not stop the processing + SetECSAction(const Netmask& v4Netmask) : + d_v4(v4Netmask), d_hasV6(false) + { + } + + SetECSAction(const Netmask& v4Netmask, const Netmask& v6Netmask) : + d_v4(v4Netmask), d_v6(v6Netmask), d_hasV6(true) + { + } + + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + if (d_hasV6) { + dnsquestion->ecs = std::make_unique(dnsquestion->ids.origRemote.isIPv4() ? d_v4 : d_v6); + } + else { + dnsquestion->ecs = std::make_unique(d_v4); + } + + return Action::None; + } + + [[nodiscard]] std::string toString() const override + { + std::string result = "set ECS to " + d_v4.toString(); + if (d_hasV6) { + result += " / " + d_v6.toString(); + } + return result; + } + +private: + Netmask d_v4; + Netmask d_v6; + bool d_hasV6; +}; + +#ifndef DISABLE_PROTOBUF +static DnstapMessage::ProtocolType ProtocolToDNSTap(dnsdist::Protocol protocol) +{ + if (protocol == dnsdist::Protocol::DoUDP) { + return DnstapMessage::ProtocolType::DoUDP; + } + if (protocol == dnsdist::Protocol::DoTCP) { + return DnstapMessage::ProtocolType::DoTCP; + } + if (protocol == dnsdist::Protocol::DoT) { + return DnstapMessage::ProtocolType::DoT; + } + if (protocol == dnsdist::Protocol::DoH || protocol == dnsdist::Protocol::DoH3) { + return DnstapMessage::ProtocolType::DoH; + } + if (protocol == dnsdist::Protocol::DNSCryptUDP) { + return DnstapMessage::ProtocolType::DNSCryptUDP; + } + if (protocol == dnsdist::Protocol::DNSCryptTCP) { + return DnstapMessage::ProtocolType::DNSCryptTCP; + } + if (protocol == dnsdist::Protocol::DoQ) { + return DnstapMessage::ProtocolType::DoQ; + } + throw std::runtime_error("Unhandled protocol for dnstap: " + protocol.toPrettyString()); +} + +static void remoteLoggerQueueData(RemoteLoggerInterface& remoteLogger, const std::string& data) +{ + auto ret = remoteLogger.queueData(data); + + switch (ret) { + case RemoteLoggerInterface::Result::Queued: + break; + case RemoteLoggerInterface::Result::PipeFull: { + vinfolog("%s: %s", remoteLogger.name(), RemoteLoggerInterface::toErrorString(ret)); + break; + } + case RemoteLoggerInterface::Result::TooLarge: { + warnlog("%s: %s", remoteLogger.name(), RemoteLoggerInterface::toErrorString(ret)); + break; + } + case RemoteLoggerInterface::Result::OtherError: + warnlog("%s: %s", remoteLogger.name(), RemoteLoggerInterface::toErrorString(ret)); + } +} + +class DnstapLogAction : public DNSAction, public boost::noncopyable +{ +public: + // this action does not stop the processing + DnstapLogAction(std::string identity, std::shared_ptr& logger, std::optional> alterFunc) : + d_identity(std::move(identity)), d_logger(logger), d_alterFunc(std::move(alterFunc)) + { + } + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + static thread_local std::string data; + data.clear(); + + DnstapMessage::ProtocolType protocol = ProtocolToDNSTap(dnsquestion->getProtocol()); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + DnstapMessage message(std::move(data), !dnsquestion->getHeader()->qr ? DnstapMessage::MessageType::client_query : DnstapMessage::MessageType::client_response, d_identity, &dnsquestion->ids.origRemote, &dnsquestion->ids.origDest, protocol, reinterpret_cast(dnsquestion->getData().data()), dnsquestion->getData().size(), &dnsquestion->getQueryRealTime(), nullptr); + { + if (d_alterFunc) { + auto lock = g_lua.lock(); + (*d_alterFunc)(dnsquestion, &message); + } + } + + data = message.getBuffer(); + remoteLoggerQueueData(*d_logger, data); + + return Action::None; + } + [[nodiscard]] std::string toString() const override + { + return "remote log as dnstap to " + (d_logger ? d_logger->toString() : ""); + } + +private: + std::string d_identity; + std::shared_ptr d_logger; + std::optional> d_alterFunc; +}; + +namespace +{ + void addMetaDataToProtobuf(DNSDistProtoBufMessage& message, const DNSQuestion& dnsquestion, const std::vector>& metas) + { + for (const auto& [name, meta] : metas) { + message.addMeta(name, meta.getValues(dnsquestion), {}); + } + } + + void addTagsToProtobuf(DNSDistProtoBufMessage& message, const DNSQuestion& dnsquestion, const std::unordered_set& allowed) + { + if (!dnsquestion.ids.qTag) { + return; + } + + for (const auto& [key, value] : *dnsquestion.ids.qTag) { + if (!allowed.empty() && allowed.count(key) == 0) { + continue; + } + + if (value.empty()) { + message.addTag(key); + } + else { + auto tag = key; + tag.append(":"); + tag.append(value); + message.addTag(tag); + } + } + } + + void addExtendedDNSErrorToProtobuf(DNSDistProtoBufMessage& message, const DNSResponse& response, const std::string& metaKey) + { + auto [infoCode, extraText] = dnsdist::edns::getExtendedDNSError(response.getData()); + if (!infoCode) { + return; + } + + if (extraText) { + message.addMeta(metaKey, {*extraText}, {*infoCode}); + } + else { + message.addMeta(metaKey, {}, {*infoCode}); + } + } +} + +class RemoteLogAction : public DNSAction, public boost::noncopyable +{ +public: + // this action does not stop the processing + RemoteLogAction(RemoteLogActionConfiguration& config) : + d_tagsToExport(std::move(config.tagsToExport)), d_metas(std::move(config.metas)), d_logger(config.logger), d_alterFunc(std::move(config.alterQueryFunc)), d_serverID(config.serverID), d_ipEncryptKey(config.ipEncryptKey) + { + } + + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + if (!dnsquestion->ids.d_protoBufData) { + dnsquestion->ids.d_protoBufData = std::make_unique(); + } + if (!dnsquestion->ids.d_protoBufData->uniqueId) { + dnsquestion->ids.d_protoBufData->uniqueId = getUniqueID(); + } + + DNSDistProtoBufMessage message(*dnsquestion); + if (!d_serverID.empty()) { + message.setServerIdentity(d_serverID); + } + +#ifdef HAVE_IPCIPHER + if (!d_ipEncryptKey.empty()) { + message.setRequestor(encryptCA(dnsquestion->ids.origRemote, d_ipEncryptKey)); + } +#endif /* HAVE_IPCIPHER */ + + if (d_tagsToExport) { + addTagsToProtobuf(message, *dnsquestion, *d_tagsToExport); + } + + addMetaDataToProtobuf(message, *dnsquestion, d_metas); + + if (d_alterFunc) { + auto lock = g_lua.lock(); + (*d_alterFunc)(dnsquestion, &message); + } + + static thread_local std::string data; + data.clear(); + message.serialize(data); + remoteLoggerQueueData(*d_logger, data); + + return Action::None; + } + [[nodiscard]] std::string toString() const override + { + return "remote log to " + (d_logger ? d_logger->toString() : ""); + } + +private: + std::optional> d_tagsToExport; + std::vector> d_metas; + std::shared_ptr d_logger; + std::optional> d_alterFunc; + std::string d_serverID; + std::string d_ipEncryptKey; +}; + +#endif /* DISABLE_PROTOBUF */ + +class SNMPTrapAction : public DNSAction +{ +public: + // this action does not stop the processing + SNMPTrapAction(std::string reason) : + d_reason(std::move(reason)) + { + } + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + if (g_snmpAgent != nullptr && dnsdist::configuration::getImmutableConfiguration().d_snmpTrapsEnabled) { + g_snmpAgent->sendDNSTrap(*dnsquestion, d_reason); + } + + return Action::None; + } + [[nodiscard]] std::string toString() const override + { + return "send SNMP trap"; + } + +private: + std::string d_reason; +}; + +class SetTagAction : public DNSAction +{ +public: + // this action does not stop the processing + SetTagAction(std::string tag, std::string value) : + d_tag(std::move(tag)), d_value(std::move(value)) + { + } + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + dnsquestion->setTag(d_tag, d_value); + + return Action::None; + } + [[nodiscard]] std::string toString() const override + { + return "set tag '" + d_tag + "' to value '" + d_value + "'"; + } + +private: + std::string d_tag; + std::string d_value; +}; + +#ifndef DISABLE_PROTOBUF +class DnstapLogResponseAction : public DNSResponseAction, public boost::noncopyable +{ +public: + // this action does not stop the processing + DnstapLogResponseAction(std::string identity, std::shared_ptr& logger, std::optional> alterFunc) : + d_identity(std::move(identity)), d_logger(logger), d_alterFunc(std::move(alterFunc)) + { + } + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override + { + static thread_local std::string data; + struct timespec now = {}; + gettime(&now, true); + data.clear(); + + DnstapMessage::ProtocolType protocol = ProtocolToDNSTap(response->getProtocol()); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + DnstapMessage message(std::move(data), DnstapMessage::MessageType::client_response, d_identity, &response->ids.origRemote, &response->ids.origDest, protocol, reinterpret_cast(response->getData().data()), response->getData().size(), &response->getQueryRealTime(), &now); + { + if (d_alterFunc) { + auto lock = g_lua.lock(); + (*d_alterFunc)(response, &message); + } + } + + data = message.getBuffer(); + remoteLoggerQueueData(*d_logger, data); + + return Action::None; + } + [[nodiscard]] std::string toString() const override + { + return "log response as dnstap to " + (d_logger ? d_logger->toString() : ""); + } + +private: + std::string d_identity; + std::shared_ptr d_logger; + std::optional> d_alterFunc; +}; + +class RemoteLogResponseAction : public DNSResponseAction, public boost::noncopyable +{ +public: + // this action does not stop the processing + RemoteLogResponseAction(RemoteLogActionConfiguration& config) : + d_tagsToExport(std::move(config.tagsToExport)), d_metas(std::move(config.metas)), d_logger(config.logger), d_alterFunc(std::move(config.alterResponseFunc)), d_serverID(config.serverID), d_ipEncryptKey(config.ipEncryptKey), d_exportExtendedErrorsToMeta(std::move(config.exportExtendedErrorsToMeta)), d_includeCNAME(config.includeCNAME) + { + } + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override + { + if (!response->ids.d_protoBufData) { + response->ids.d_protoBufData = std::make_unique(); + } + if (!response->ids.d_protoBufData->uniqueId) { + response->ids.d_protoBufData->uniqueId = getUniqueID(); + } + + DNSDistProtoBufMessage message(*response, d_includeCNAME); + if (!d_serverID.empty()) { + message.setServerIdentity(d_serverID); + } + +#ifdef HAVE_IPCIPHER + if (!d_ipEncryptKey.empty()) { + message.setRequestor(encryptCA(response->ids.origRemote, d_ipEncryptKey)); + } +#endif /* HAVE_IPCIPHER */ + + if (d_tagsToExport) { + addTagsToProtobuf(message, *response, *d_tagsToExport); + } + + addMetaDataToProtobuf(message, *response, d_metas); + + if (d_exportExtendedErrorsToMeta) { + addExtendedDNSErrorToProtobuf(message, *response, *d_exportExtendedErrorsToMeta); + } + + if (d_alterFunc) { + auto lock = g_lua.lock(); + (*d_alterFunc)(response, &message); + } + + static thread_local std::string data; + data.clear(); + message.serialize(data); + d_logger->queueData(data); + + return Action::None; + } + [[nodiscard]] std::string toString() const override + { + return "remote log response to " + (d_logger ? d_logger->toString() : ""); + } + +private: + std::optional> d_tagsToExport; + std::vector> d_metas; + std::shared_ptr d_logger; + std::optional> d_alterFunc; + std::string d_serverID; + std::string d_ipEncryptKey; + std::optional d_exportExtendedErrorsToMeta{std::nullopt}; + bool d_includeCNAME; +}; + +#endif /* DISABLE_PROTOBUF */ + +class DropResponseAction : public DNSResponseAction +{ +public: + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override + { + return Action::Drop; + } + [[nodiscard]] std::string toString() const override + { + return "drop"; + } +}; + +class AllowResponseAction : public DNSResponseAction +{ +public: + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override + { + return Action::Allow; + } + [[nodiscard]] std::string toString() const override + { + return "allow"; + } +}; + +class DelayResponseAction : public DNSResponseAction +{ +public: + DelayResponseAction(int msec) : + d_msec(msec) + { + } + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override + { + *ruleresult = std::to_string(d_msec); + return Action::Delay; + } + [[nodiscard]] std::string toString() const override + { + return "delay by " + std::to_string(d_msec) + " ms"; + } + +private: + int d_msec; +}; + +class SNMPTrapResponseAction : public DNSResponseAction +{ +public: + // this action does not stop the processing + SNMPTrapResponseAction(std::string reason) : + d_reason(std::move(reason)) + { + } + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override + { + if (g_snmpAgent != nullptr && dnsdist::configuration::getImmutableConfiguration().d_snmpTrapsEnabled) { + g_snmpAgent->sendDNSTrap(*response, d_reason); + } + + return Action::None; + } + [[nodiscard]] std::string toString() const override + { + return "send SNMP trap"; + } + +private: + std::string d_reason; +}; + +class SetTagResponseAction : public DNSResponseAction +{ +public: + // this action does not stop the processing + SetTagResponseAction(std::string tag, std::string value) : + d_tag(std::move(tag)), d_value(std::move(value)) + { + } + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override + { + response->setTag(d_tag, d_value); + + return Action::None; + } + [[nodiscard]] std::string toString() const override + { + return "set tag '" + d_tag + "' to value '" + d_value + "'"; + } + +private: + std::string d_tag; + std::string d_value; +}; + +class ClearRecordTypesResponseAction : public DNSResponseAction, public boost::noncopyable +{ +public: + ClearRecordTypesResponseAction(std::unordered_set qtypes) : + d_qtypes(std::move(qtypes)) + { + } + + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override + { + if (!d_qtypes.empty()) { + clearDNSPacketRecordTypes(response->getMutableData(), d_qtypes); + } + return DNSResponseAction::Action::None; + } + + [[nodiscard]] std::string toString() const override + { + return "clear record types"; + } + +private: + std::unordered_set d_qtypes{}; +}; + +class ContinueAction : public DNSAction +{ +public: + // this action does not stop the processing + ContinueAction(std::shared_ptr& action) : + d_action(action) + { + } + + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + if (d_action) { + /* call the action */ + auto action = (*d_action)(dnsquestion, ruleresult); + bool drop = false; + /* apply the changes if needed (pool selection, flags, etc */ + processRulesResult(action, *dnsquestion, *ruleresult, drop); + } + + /* but ignore the resulting action no matter what */ + return Action::None; + } + + [[nodiscard]] std::string toString() const override + { + if (d_action) { + return "continue after: " + (d_action ? d_action->toString() : ""); + } + return "no op"; + } + +private: + std::shared_ptr d_action; +}; + +#if defined(HAVE_DNS_OVER_HTTPS) || defined(HAVE_DNS_OVER_HTTP3) +class HTTPStatusAction : public DNSAction +{ +public: + HTTPStatusAction(int code, PacketBuffer body, std::string contentType, dnsdist::ResponseConfig responseConfig) : + d_responseConfig(responseConfig), d_body(std::move(body)), d_contentType(std::move(contentType)), d_code(code) + { + } + + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { +#if defined(HAVE_DNS_OVER_HTTPS) + if (dnsquestion->ids.du) { + dnsquestion->ids.du->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType); + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) { + header.qr = true; // for good measure + setResponseHeadersFromConfig(header, d_responseConfig); + return true; + }); + return Action::HeaderModify; + } +#endif /* HAVE_DNS_OVER_HTTPS */ +#if defined(HAVE_DNS_OVER_HTTP3) + if (dnsquestion->ids.doh3u) { + dnsquestion->ids.doh3u->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType); + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) { + header.qr = true; // for good measure + setResponseHeadersFromConfig(header, d_responseConfig); + return true; + }); + return Action::HeaderModify; + } +#endif /* HAVE_DNS_OVER_HTTP3 */ + return Action::None; + } + + [[nodiscard]] std::string toString() const override + { + return "return an HTTP status of " + std::to_string(d_code); + } + +private: + dnsdist::ResponseConfig d_responseConfig; + PacketBuffer d_body; + std::string d_contentType; + int d_code; +}; +#endif /* HAVE_DNS_OVER_HTTPS || HAVE_DNS_OVER_HTTP3 */ + +#if defined(HAVE_LMDB) || defined(HAVE_CDB) +class KeyValueStoreLookupAction : public DNSAction +{ +public: + // this action does not stop the processing + KeyValueStoreLookupAction(std::shared_ptr& kvs, std::shared_ptr& lookupKey, std::string destinationTag) : + d_kvs(kvs), d_key(lookupKey), d_tag(std::move(destinationTag)) + { + } + + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + std::vector keys = d_key->getKeys(*dnsquestion); + std::string result; + for (const auto& key : keys) { + if (d_kvs->getValue(key, result)) { + break; + } + } + + dnsquestion->setTag(d_tag, std::move(result)); + + return Action::None; + } + + [[nodiscard]] std::string toString() const override + { + return "lookup key-value store based on '" + d_key->toString() + "' and set the result in tag '" + d_tag + "'"; + } + +private: + std::shared_ptr d_kvs; + std::shared_ptr d_key; + std::string d_tag; +}; + +class KeyValueStoreRangeLookupAction : public DNSAction +{ +public: + // this action does not stop the processing + KeyValueStoreRangeLookupAction(std::shared_ptr& kvs, std::shared_ptr& lookupKey, std::string destinationTag) : + d_kvs(kvs), d_key(lookupKey), d_tag(std::move(destinationTag)) + { + } + + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + std::vector keys = d_key->getKeys(*dnsquestion); + std::string result; + for (const auto& key : keys) { + if (d_kvs->getRangeValue(key, result)) { + break; + } + } + + dnsquestion->setTag(d_tag, std::move(result)); + + return Action::None; + } + + [[nodiscard]] std::string toString() const override + { + return "do a range-based lookup in key-value store based on '" + d_key->toString() + "' and set the result in tag '" + d_tag + "'"; + } + +private: + std::shared_ptr d_kvs; + std::shared_ptr d_key; + std::string d_tag; +}; +#endif /* defined(HAVE_LMDB) || defined(HAVE_CDB) */ + +class SetMaxReturnedTTLAction : public DNSAction +{ +public: + SetMaxReturnedTTLAction(uint32_t cap) : + d_cap(cap) + { + } + + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + dnsquestion->ids.ttlCap = d_cap; + return DNSAction::Action::None; + } + + [[nodiscard]] std::string toString() const override + { + return "cap the TTL of the returned response to " + std::to_string(d_cap); + } + +private: + uint32_t d_cap; +}; + +class SetMaxReturnedTTLResponseAction : public DNSResponseAction +{ +public: + SetMaxReturnedTTLResponseAction(uint32_t cap) : + d_cap(cap) + { + } + + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override + { + response->ids.ttlCap = d_cap; + return DNSResponseAction::Action::None; + } + + [[nodiscard]] std::string toString() const override + { + return "cap the TTL of the returned response to " + std::to_string(d_cap); + } + +private: + uint32_t d_cap; +}; + +class NegativeAndSOAAction : public DNSAction +{ +public: + NegativeAndSOAAction(bool nxd, DNSName zone, uint32_t ttl, DNSName mname, DNSName rname, dnsdist::actions::SOAParams params, bool soaInAuthoritySection, dnsdist::ResponseConfig responseConfig) : + d_responseConfig(responseConfig), d_zone(std::move(zone)), d_mname(std::move(mname)), d_rname(std::move(rname)), d_ttl(ttl), d_params(params), d_nxd(nxd), d_soaInAuthoritySection(soaInAuthoritySection) + { + } + + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + if (!setNegativeAndAdditionalSOA(*dnsquestion, d_nxd, d_zone, d_ttl, d_mname, d_rname, d_params.serial, d_params.refresh, d_params.retry, d_params.expire, d_params.minimum, d_soaInAuthoritySection)) { + return Action::None; + } + + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) { + setResponseHeadersFromConfig(header, d_responseConfig); + return true; + }); + + return Action::Allow; + } + + [[nodiscard]] std::string toString() const override + { + return std::string(d_nxd ? "NXD" : "NODATA") + " with SOA"; + } + +private: + dnsdist::ResponseConfig d_responseConfig; + DNSName d_zone; + DNSName d_mname; + DNSName d_rname; + uint32_t d_ttl; + dnsdist::actions::SOAParams d_params; + bool d_nxd; + bool d_soaInAuthoritySection; +}; + +class SetProxyProtocolValuesAction : public DNSAction +{ +public: + // this action does not stop the processing + SetProxyProtocolValuesAction(const std::vector>& values) + { + d_values.reserve(values.size()); + for (const auto& value : values) { + d_values.push_back({value.second, value.first}); + } + } + + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + if (!dnsquestion->proxyProtocolValues) { + dnsquestion->proxyProtocolValues = make_unique>(); + } + + *(dnsquestion->proxyProtocolValues) = d_values; + + return Action::None; + } + + [[nodiscard]] std::string toString() const override + { + return "set Proxy-Protocol values"; + } + +private: + std::vector d_values; +}; + +class SetAdditionalProxyProtocolValueAction : public DNSAction +{ +public: + // this action does not stop the processing + SetAdditionalProxyProtocolValueAction(uint8_t type, std::string value) : + d_value(std::move(value)), d_type(type) + { + } + + DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override + { + if (!dnsquestion->proxyProtocolValues) { + dnsquestion->proxyProtocolValues = make_unique>(); + } + + dnsquestion->proxyProtocolValues->push_back({d_value, d_type}); + + return Action::None; + } + + [[nodiscard]] std::string toString() const override + { + return "add a Proxy-Protocol value of type " + std::to_string(d_type); + } + +private: + std::string d_value; + uint8_t d_type; +}; + +class SetReducedTTLResponseAction : public DNSResponseAction, public boost::noncopyable +{ +public: + // this action does not stop the processing + SetReducedTTLResponseAction(uint8_t percentage) : + d_ratio(percentage / 100.0) + { + } + + DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override + { + // NOLINTNEXTLINE(bugprone-easily-swappable-parameters) + auto visitor = [&](uint8_t section, uint16_t qclass, uint16_t qtype, uint32_t ttl) { + return ttl * d_ratio; + }; + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + editDNSPacketTTL(reinterpret_cast(response->getMutableData().data()), response->getData().size(), visitor); + return DNSResponseAction::Action::None; + } + + [[nodiscard]] std::string toString() const override + { + return "reduce ttl to " + std::to_string(d_ratio * 100) + " percent of its value"; + } + +private: + double d_ratio{1.0}; +}; + +class SetExtendedDNSErrorAction : public DNSAction +{ +public: + // this action does not stop the processing + SetExtendedDNSErrorAction(uint16_t infoCode, const std::string& extraText) + { + d_ede.infoCode = infoCode; + d_ede.extraText = extraText; + } + + DNSAction::Action operator()(DNSQuestion* dnsQuestion, std::string* ruleresult) const override + { + dnsQuestion->ids.d_extendedError = std::make_unique(d_ede); + + return DNSAction::Action::None; + } + + [[nodiscard]] std::string toString() const override + { + return "set EDNS Extended DNS Error to " + std::to_string(d_ede.infoCode) + (d_ede.extraText.empty() ? std::string() : std::string(": \"") + d_ede.extraText + std::string("\"")); + } + +private: + EDNSExtendedError d_ede; +}; + +class SetExtendedDNSErrorResponseAction : public DNSResponseAction +{ +public: + // this action does not stop the processing + SetExtendedDNSErrorResponseAction(uint16_t infoCode, const std::string& extraText) + { + d_ede.infoCode = infoCode; + d_ede.extraText = extraText; + } + + DNSResponseAction::Action operator()(DNSResponse* dnsResponse, std::string* ruleresult) const override + { + dnsResponse->ids.d_extendedError = std::make_unique(d_ede); + + return DNSResponseAction::Action::None; + } + + [[nodiscard]] std::string toString() const override + { + return "set EDNS Extended DNS Error to " + std::to_string(d_ede.infoCode) + (d_ede.extraText.empty() ? std::string() : std::string(": \"") + d_ede.extraText + std::string("\"")); + } + +private: + EDNSExtendedError d_ede; +}; + +class LimitTTLResponseAction : public DNSResponseAction, public boost::noncopyable +{ +public: + LimitTTLResponseAction() {} + + LimitTTLResponseAction(uint32_t min, uint32_t max = std::numeric_limits::max(), const std::unordered_set& types = {}) : + d_types(types), d_min(min), d_max(max) + { + } + + DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override + { + dnsdist::PacketMangling::restrictDNSPacketTTLs(dr->getMutableData(), d_min, d_max, d_types); + return DNSResponseAction::Action::None; + } + + std::string toString() const override + { + std::string result = "limit ttl (" + std::to_string(d_min) + " <= ttl <= " + std::to_string(d_max); + if (!d_types.empty()) { + bool first = true; + result += ", types in ["; + for (const auto& type : d_types) { + if (first) { + first = false; + } + else { + result += " "; + } + result += type.toString(); + } + result += "]"; + } + result += +")"; + return result; + } + +private: + std::unordered_set d_types; + uint32_t d_min{0}; + uint32_t d_max{std::numeric_limits::max()}; +}; + +#ifndef DISABLE_PROTOBUF +std::shared_ptr getRemoteLogAction(RemoteLogActionConfiguration& config) +{ + return std::shared_ptr(new RemoteLogAction(config)); +} + +std::shared_ptr getRemoteLogResponseAction(RemoteLogActionConfiguration& config) +{ + return std::shared_ptr(new RemoteLogResponseAction(config)); +} + +std::shared_ptr getDnstapLogAction(const std::string& identity, std::shared_ptr logger, std::optional alterFunc) +{ + return std::shared_ptr(new DnstapLogAction(identity, logger, alterFunc)); +} + +std::shared_ptr getDnstapLogResponseAction(const std::string& identity, std::shared_ptr logger, std::optional alterFunc) +{ + return std::shared_ptr(new DnstapLogResponseAction(identity, logger, alterFunc)); +} +#endif /* DISABLE_PROTOBUF */ + +#if defined(HAVE_LMDB) || defined(HAVE_CDB) +std::shared_ptr getKeyValueStoreLookupAction(std::shared_ptr& kvs, std::shared_ptr& lookupKey, const std::string& destinationTag) +{ + return std::shared_ptr(new KeyValueStoreLookupAction(kvs, lookupKey, destinationTag)); +} + +std::shared_ptr getKeyValueStoreRangeLookupAction(std::shared_ptr& kvs, std::shared_ptr& lookupKey, const std::string& destinationTag) +{ + return std::shared_ptr(new KeyValueStoreRangeLookupAction(kvs, lookupKey, destinationTag)); +} +#endif /* defined(HAVE_LMDB) || defined(HAVE_CDB) */ + +#ifdef HAVE_DNS_OVER_HTTPS +std::shared_ptr getHTTPStatusAction(uint16_t status, PacketBuffer&& body, const std::string& contentType, const dnsdist::ResponseConfig& responseConfig) +{ + return std::shared_ptr(new HTTPStatusAction(status, body, contentType, responseConfig)); +} + +#endif + +std::shared_ptr getLimitTTLResponseAction(uint32_t min, uint32_t max, std::unordered_set types) +{ + return std::shared_ptr(new LimitTTLResponseAction(min, max, types)); +} + +std::shared_ptr getMinTTLResponseAction(uint32_t min) +{ + return std::shared_ptr(new LimitTTLResponseAction(min)); +} + +std::shared_ptr getClearRecordTypesResponseAction(std::unordered_set types) +{ + return std::shared_ptr(new ClearRecordTypesResponseAction(types)); +} + +std::shared_ptr getContinueAction(std::shared_ptr action) +{ + return std::shared_ptr(new ContinueAction(action)); +} + +std::shared_ptr getNegativeAndSOAAction(bool nxd, const DNSName& zone, uint32_t ttl, const DNSName& mname, const DNSName& rname, const SOAParams& params, bool soaInAuthority, dnsdist::ResponseConfig responseConfig) +{ + return std::shared_ptr(new NegativeAndSOAAction(nxd, zone, ttl, mname, rname, params, soaInAuthority, responseConfig)); +} + +std::shared_ptr getRCodeAction(uint8_t rcode, const dnsdist::ResponseConfig& responseConfig) +{ + return std::shared_ptr(new RCodeAction(rcode, responseConfig)); +} + +std::shared_ptr getERCodeAction(uint8_t rcode, const dnsdist::ResponseConfig& responseConfig) +{ + return std::shared_ptr(new ERCodeAction(rcode, responseConfig)); +} + +std::shared_ptr getSetECSAction(const std::string& ipv4) +{ + return std::shared_ptr(new SetECSAction(Netmask(ipv4))); +} + +std::shared_ptr getSetECSAction(const std::string& ipv4, const std::string& ipv6) +{ + return std::shared_ptr(new SetECSAction(Netmask(ipv4), Netmask(ipv6))); +} + +std::shared_ptr getSpoofAction(const std::vector& addresses, const dnsdist::ResponseConfig& config) +{ + return std::shared_ptr(new SpoofAction(addresses, config)); +} + +std::shared_ptr getSpoofAction(const std::vector& rawRDatas, std::optional qtypeForAny, const dnsdist::ResponseConfig& config) +{ + return std::shared_ptr(new SpoofAction(rawRDatas, qtypeForAny, config)); +} + +std::shared_ptr getSpoofAction(const DNSName& cname, const dnsdist::ResponseConfig& config) +{ + return std::shared_ptr(new SpoofAction(cname, config)); +} + +std::shared_ptr getSpoofAction(const PacketBuffer& packet) +{ + return std::shared_ptr(new SpoofAction(packet)); +} + +std::shared_ptr getSpoofSVCAction(const std::vector& parameters, const dnsdist::ResponseConfig& responseConfig) +{ + return std::shared_ptr(new SpoofSVCAction(parameters, responseConfig)); +} + +std::shared_ptr getSetMaxReturnedTTLAction(uint32_t max) +{ + return std::shared_ptr(new SetMaxReturnedTTLAction(max)); +} + +std::shared_ptr getSetMaxReturnedTTLResponseAction(uint32_t max) +{ + return std::shared_ptr(new SetMaxReturnedTTLResponseAction(max)); +} + +std::shared_ptr getSetMaxTTLResponseAction(uint32_t max) +{ + return std::shared_ptr(new LimitTTLResponseAction(0, max)); +} + +std::shared_ptr getSetProxyProtocolValuesAction(const std::vector>& values) +{ + return std::shared_ptr(new SetProxyProtocolValuesAction(values)); +} + +std::shared_ptr getTeeAction(const ComboAddress& rca, std::optional lca, bool addECS, bool addProxyProtocol) +{ + return std::shared_ptr(new TeeAction(rca, lca, addECS, addProxyProtocol)); +} + +#include "dnsdist-actions-factory-generated.cc" +#include "dnsdist-response-actions-factory-generated.cc" +} diff --git a/pdns/dnsdistdist/dnsdist-actions-factory.hh b/pdns/dnsdistdist/dnsdist-actions-factory.hh new file mode 100644 index 0000000000..4d2cd45377 --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-actions-factory.hh @@ -0,0 +1,121 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#pragma once + +#include +#include +#include +#include +#include + +struct DNSQuestion; +struct DNSResponse; + +#include "dnsdist-actions.hh" +#include "dnsdist-protobuf.hh" +#include "dnsdist-svc.hh" +#include "dnstap.hh" +#include "iputils.hh" +#include "noinitvector.hh" + +struct dnsdist_ffi_dnsquestion_t; +struct dnsdist_ffi_dnsresponse_t; +class RemoteLoggerInterface; +class KeyValueStore; +class KeyValueLookupKey; + +namespace dnsdist::actions +{ +using LuaActionFunction = std::function>(DNSQuestion* dnsquestion)>; +using LuaResponseActionFunction = std::function>(DNSResponse* response)>; +using LuaActionFFIFunction = std::function; +using LuaResponseActionFFIFunction = std::function; + +struct SOAParams +{ + uint32_t serial; + uint32_t refresh; + uint32_t retry; + uint32_t expire; + uint32_t minimum; +}; + +#include "dnsdist-actions-factory-generated.hh" +#include "dnsdist-response-actions-factory-generated.hh" + +std::shared_ptr getContinueAction(std::shared_ptr action); +#ifdef HAVE_DNS_OVER_HTTPS +std::shared_ptr getHTTPStatusAction(uint16_t status, PacketBuffer&& body, const std::string& contentType, const dnsdist::ResponseConfig& responseConfig); +#endif +std::shared_ptr getNegativeAndSOAAction(bool nxd, const DNSName& zone, uint32_t ttl, const DNSName& mname, const DNSName& rname, const SOAParams& params, bool soaInAuthority, dnsdist::ResponseConfig responseConfig); +std::shared_ptr getSetProxyProtocolValuesAction(const std::vector>& values); +std::shared_ptr getRCodeAction(uint8_t rcode, const dnsdist::ResponseConfig& responseConfig); +std::shared_ptr getERCodeAction(uint8_t rcode, const dnsdist::ResponseConfig& responseConfig); + +#if defined(HAVE_LMDB) || defined(HAVE_CDB) +std::shared_ptr getKeyValueStoreLookupAction(std::shared_ptr& kvs, std::shared_ptr& lookupKey, const std::string& destinationTag); +std::shared_ptr getKeyValueStoreRangeLookupAction(std::shared_ptr& kvs, std::shared_ptr& lookupKey, const std::string& destinationTag); +#endif /* defined(HAVE_LMDB) || defined(HAVE_CDB) */ + +std::shared_ptr getSetECSAction(const std::string& ipv4); +std::shared_ptr getSetECSAction(const std::string& ipv4, const std::string& ipv6); +std::shared_ptr getSpoofAction(const std::vector& addresses, const dnsdist::ResponseConfig& config); +std::shared_ptr getSpoofAction(const std::vector& rawRDatas, std::optional qtypeForAny, const dnsdist::ResponseConfig& config); +std::shared_ptr getSpoofAction(const DNSName& cname, const dnsdist::ResponseConfig& config); +std::shared_ptr getSpoofAction(const PacketBuffer& packet); + +std::shared_ptr getSpoofSVCAction(const std::vector& parameters, const dnsdist::ResponseConfig& responseConfig); + +std::shared_ptr getSetMaxReturnedTTLAction(uint32_t max); +std::shared_ptr getLimitTTLResponseAction(uint32_t min, uint32_t max = std::numeric_limits::max(), std::unordered_set types = {}); +std::shared_ptr getMinTTLResponseAction(uint32_t min); +std::shared_ptr getSetMaxReturnedTTLResponseAction(uint32_t max); +std::shared_ptr getSetMaxTTLResponseAction(uint32_t max); + +std::shared_ptr getClearRecordTypesResponseAction(std::unordered_set types); + +std::shared_ptr getTeeAction(const ComboAddress& rca, std::optional lca, bool addECS, bool addProxyProtocol); + +#ifndef DISABLE_PROTOBUF +using ProtobufAlterFunction = std::function; +using ProtobufAlterResponseFunction = std::function; +using DnstapAlterFunction = std::function; +using DnstapAlterResponseFunction = std::function; + +struct RemoteLogActionConfiguration +{ + std::vector> metas; + std::optional> tagsToExport{std::nullopt}; + std::optional alterQueryFunc; + std::optional alterResponseFunc; + std::shared_ptr logger; + std::string serverID; + std::string ipEncryptKey; + std::optional exportExtendedErrorsToMeta{std::nullopt}; + bool includeCNAME{false}; +}; +std::shared_ptr getRemoteLogAction(RemoteLogActionConfiguration& config); +std::shared_ptr getRemoteLogResponseAction(RemoteLogActionConfiguration& config); +std::shared_ptr getDnstapLogAction(const std::string& identity, std::shared_ptr logger, std::optional alterFunc); +std::shared_ptr getDnstapLogResponseAction(const std::string& identity, std::shared_ptr logger, std::optional alterFunc); +#endif /* DISABLE_PROTOBUF */ +} diff --git a/pdns/dnsdistdist/dnsdist-lua-actions.cc b/pdns/dnsdistdist/dnsdist-lua-actions.cc index 24c13b4482..f35246ceba 100644 --- a/pdns/dnsdistdist/dnsdist-lua-actions.cc +++ b/pdns/dnsdistdist/dnsdist-lua-actions.cc @@ -20,2318 +20,15 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. */ #include "config.h" -#include "threadname.hh" #include "dnsdist.hh" -#include "dnsdist-async.hh" +#include "dnsdist-actions-factory.hh" #include "dnsdist-dnsparser.hh" -#include "dnsdist-ecs.hh" -#include "dnsdist-edns.hh" #include "dnsdist-lua.hh" #include "dnsdist-lua-ffi.hh" -#include "dnsdist-mac-address.hh" #include "dnsdist-protobuf.hh" -#include "dnsdist-proxy-protocol.hh" -#include "dnsdist-kvs.hh" #include "dnsdist-rule-chains.hh" -#include "dnsdist-self-answers.hh" -#include "dnsdist-snmp.hh" -#include "dnsdist-svc.hh" - -#include "dnstap.hh" -#include "dnswriter.hh" -#include "ednsoptions.hh" -#include "fstrm_logger.hh" -#include "remote_logger.hh" -#include "svc-records.hh" - -#include - -#include "ipcipher.hh" - -class DropAction : public DNSAction -{ -public: - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - return Action::Drop; - } - [[nodiscard]] std::string toString() const override - { - return "drop"; - } -}; - -class AllowAction : public DNSAction -{ -public: - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - return Action::Allow; - } - [[nodiscard]] std::string toString() const override - { - return "allow"; - } -}; - -class NoneAction : public DNSAction -{ -public: - // this action does not stop the processing - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - return Action::None; - } - [[nodiscard]] std::string toString() const override - { - return "no op"; - } -}; - -class QPSAction : public DNSAction -{ -public: - QPSAction(int limit) : - d_qps(QPSLimiter(limit, limit)) - { - } - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - if (d_qps.lock()->check()) { - return Action::None; - } - return Action::Drop; - } - [[nodiscard]] std::string toString() const override - { - return "qps limit to " + std::to_string(d_qps.lock()->getRate()); - } - -private: - mutable LockGuarded d_qps; -}; - -class DelayAction : public DNSAction -{ -public: - DelayAction(int msec) : - d_msec(msec) - { - } - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - *ruleresult = std::to_string(d_msec); - return Action::Delay; - } - [[nodiscard]] std::string toString() const override - { - return "delay by " + std::to_string(d_msec) + " ms"; - } - -private: - int d_msec; -}; - -class LimitTTLResponseAction : public DNSResponseAction, public boost::noncopyable -{ -public: - LimitTTLResponseAction() {} - - LimitTTLResponseAction(uint32_t min, uint32_t max = std::numeric_limits::max(), const std::unordered_set& types = {}) : - d_types(types), d_min(min), d_max(max) - { - } - - DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override - { - dnsdist::PacketMangling::restrictDNSPacketTTLs(dr->getMutableData(), d_min, d_max, d_types); - return DNSResponseAction::Action::None; - } - - std::string toString() const override - { - std::string result = "limit ttl (" + std::to_string(d_min) + " <= ttl <= " + std::to_string(d_max); - if (!d_types.empty()) { - bool first = true; - result += ", types in ["; - for (const auto& type : d_types) { - if (first) { - first = false; - } - else { - result += " "; - } - result += type.toString(); - } - result += "]"; - } - result += +")"; - return result; - } - -private: - std::unordered_set d_types; - uint32_t d_min{0}; - uint32_t d_max{std::numeric_limits::max()}; -}; - -class TeeAction : public DNSAction -{ -public: - // this action does not stop the processing - TeeAction(const ComboAddress& rca, const boost::optional& lca, bool addECS = false, bool addProxyProtocol = false); - TeeAction(TeeAction& other) = delete; - TeeAction(TeeAction&& other) = delete; - TeeAction& operator=(TeeAction& other) = delete; - TeeAction& operator=(TeeAction&& other) = delete; - ~TeeAction() override; - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override; - [[nodiscard]] std::string toString() const override; - std::map getStats() const override; - -private: - void worker(); - - ComboAddress d_remote; - std::thread d_worker; - Socket d_socket; - mutable std::atomic d_senderrors{0}; - unsigned long d_recverrors{0}; - mutable std::atomic d_queries{0}; - stat_t d_responses{0}; - stat_t d_nxdomains{0}; - stat_t d_servfails{0}; - stat_t d_refuseds{0}; - stat_t d_formerrs{0}; - stat_t d_notimps{0}; - stat_t d_noerrors{0}; - mutable stat_t d_tcpdrops{0}; - stat_t d_otherrcode{0}; - std::atomic d_pleaseQuit{false}; - bool d_addECS{false}; - bool d_addProxyProtocol{false}; -}; - -TeeAction::TeeAction(const ComboAddress& rca, const boost::optional& lca, bool addECS, bool addProxyProtocol) : - d_remote(rca), d_socket(d_remote.sin4.sin_family, SOCK_DGRAM, 0), d_addECS(addECS), d_addProxyProtocol(addProxyProtocol) -{ - if (lca) { - d_socket.bind(*lca, false); - } - d_socket.connect(d_remote); - d_socket.setNonBlocking(); - d_worker = std::thread([this]() { - worker(); - }); -} - -TeeAction::~TeeAction() -{ - d_pleaseQuit = true; - close(d_socket.releaseHandle()); - d_worker.join(); -} - -DNSAction::Action TeeAction::operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const -{ - if (dnsquestion->overTCP()) { - d_tcpdrops++; - return DNSAction::Action::None; - } - - d_queries++; - - PacketBuffer query; - if (d_addECS) { - query = dnsquestion->getData(); - bool ednsAdded = false; - bool ecsAdded = false; - - std::string newECSOption; - generateECSOption(dnsquestion->ecs ? dnsquestion->ecs->getNetwork() : dnsquestion->ids.origRemote, newECSOption, dnsquestion->ecs ? dnsquestion->ecs->getBits() : dnsquestion->ecsPrefixLength); - - if (!handleEDNSClientSubnet(query, dnsquestion->getMaximumSize(), dnsquestion->ids.qname.wirelength(), ednsAdded, ecsAdded, dnsquestion->ecsOverride, newECSOption)) { - return DNSAction::Action::None; - } - } - - if (d_addProxyProtocol) { - auto proxyPayload = getProxyProtocolPayload(*dnsquestion); - if (query.empty()) { - query = dnsquestion->getData(); - } - if (!addProxyProtocol(query, proxyPayload)) { - return DNSAction::Action::None; - } - } - - { - const PacketBuffer& payload = query.empty() ? dnsquestion->getData() : query; - auto res = send(d_socket.getHandle(), payload.data(), payload.size(), 0); - - if (res <= 0) { - d_senderrors++; - } - } - - return DNSAction::Action::None; -} - -std::string TeeAction::toString() const -{ - return "tee to " + d_remote.toStringWithPort(); -} - -std::map TeeAction::getStats() const -{ - return {{"queries", d_queries}, - {"responses", d_responses}, - {"recv-errors", d_recverrors}, - {"send-errors", d_senderrors}, - {"noerrors", d_noerrors}, - {"nxdomains", d_nxdomains}, - {"refuseds", d_refuseds}, - {"servfails", d_servfails}, - {"other-rcode", d_otherrcode}, - {"tcp-drops", d_tcpdrops}}; -} - -void TeeAction::worker() -{ - setThreadName("dnsdist/TeeWork"); - std::array packet{}; - ssize_t res = 0; - const dnsheader_aligned dnsheader(packet.data()); - for (;;) { - res = waitForData(d_socket.getHandle(), 0, 250000); - if (d_pleaseQuit) { - break; - } - - if (res < 0) { - usleep(250000); - continue; - } - if (res == 0) { - continue; - } - res = recv(d_socket.getHandle(), packet.data(), packet.size(), 0); - if (static_cast(res) <= sizeof(struct dnsheader)) { - d_recverrors++; - } - else { - d_responses++; - } - - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well - if (dnsheader->rcode == RCode::NoError) { - d_noerrors++; - } - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well - else if (dnsheader->rcode == RCode::ServFail) { - d_servfails++; - } - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well - else if (dnsheader->rcode == RCode::NXDomain) { - d_nxdomains++; - } - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well - else if (dnsheader->rcode == RCode::Refused) { - d_refuseds++; - } - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well - else if (dnsheader->rcode == RCode::FormErr) { - d_formerrs++; - } - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions): rcode is unsigned, RCode::rcodes_ as well - else if (dnsheader->rcode == RCode::NotImp) { - d_notimps++; - } - } -} - -class PoolAction : public DNSAction -{ -public: - PoolAction(std::string pool, bool stopProcessing) : - d_pool(std::move(pool)), d_stopProcessing(stopProcessing) {} - - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - if (d_stopProcessing) { - /* we need to do it that way to keep compatiblity with custom Lua actions returning DNSAction.Pool, 'poolname' */ - *ruleresult = d_pool; - return Action::Pool; - } - dnsquestion->ids.poolName = d_pool; - return Action::None; - } - - [[nodiscard]] std::string toString() const override - { - return "to pool " + d_pool; - } - -private: - // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - const std::string d_pool; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - const bool d_stopProcessing; -}; - -class QPSPoolAction : public DNSAction -{ -public: - QPSPoolAction(unsigned int limit, std::string pool, bool stopProcessing) : - d_qps(QPSLimiter(limit, limit)), d_pool(std::move(pool)), d_stopProcessing(stopProcessing) {} - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - if (d_qps.lock()->check()) { - if (d_stopProcessing) { - /* we need to do it that way to keep compatiblity with custom Lua actions returning DNSAction.Pool, 'poolname' */ - *ruleresult = d_pool; - return Action::Pool; - } - dnsquestion->ids.poolName = d_pool; - } - return Action::None; - } - [[nodiscard]] std::string toString() const override - { - return "max " + std::to_string(d_qps.lock()->getRate()) + " to pool " + d_pool; - } - -private: - mutable LockGuarded d_qps; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - const std::string d_pool; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - const bool d_stopProcessing; -}; - -class RCodeAction : public DNSAction -{ -public: - RCodeAction(uint8_t rcode, dnsdist::ResponseConfig responseConfig) : - d_responseConfig(responseConfig), d_rcode(rcode) {} - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) { - header.rcode = d_rcode; - header.qr = true; // for good measure - setResponseHeadersFromConfig(header, d_responseConfig); - return true; - }); - return Action::HeaderModify; - } - [[nodiscard]] std::string toString() const override - { - return "set rcode " + std::to_string(d_rcode); - } - -private: - dnsdist::ResponseConfig d_responseConfig; - uint8_t d_rcode; -}; - -class ERCodeAction : public DNSAction -{ -public: - ERCodeAction(uint8_t rcode, dnsdist::ResponseConfig responseConfig) : - d_responseConfig(responseConfig), d_rcode(rcode) {} - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) { - header.rcode = (d_rcode & 0xF); - header.qr = true; // for good measure - setResponseHeadersFromConfig(header, d_responseConfig); - return true; - }); - dnsquestion->ednsRCode = ((d_rcode & 0xFFF0) >> 4); - return Action::HeaderModify; - } - [[nodiscard]] std::string toString() const override - { - return "set ercode " + ERCode::to_s(d_rcode); - } - -private: - dnsdist::ResponseConfig d_responseConfig; - uint8_t d_rcode; -}; - -class SpoofSVCAction : public DNSAction -{ -public: - SpoofSVCAction(const LuaArray& parameters, dnsdist::ResponseConfig responseConfig) : - d_responseConfig(responseConfig) - { - d_payloads.reserve(parameters.size()); - - for (const auto& param : parameters) { - std::vector payload; - if (!generateSVCPayload(payload, param.second)) { - throw std::runtime_error("Unable to generate a valid SVC record from the supplied parameters"); - } - - d_payloads.push_back(std::move(payload)); - - for (const auto& hint : param.second.ipv4hints) { - d_additionals4.insert({param.second.target, ComboAddress(hint)}); - } - - for (const auto& hint : param.second.ipv6hints) { - d_additionals6.insert({param.second.target, ComboAddress(hint)}); - } - } - } - - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - if (!dnsdist::svc::generateSVCResponse(*dnsquestion, d_payloads, d_additionals4, d_additionals6, d_responseConfig)) { - return Action::None; - } - - return Action::HeaderModify; - } - - [[nodiscard]] std::string toString() const override - { - return "spoof SVC record "; - } - -private: - dnsdist::ResponseConfig d_responseConfig; - std::vector> d_payloads{}; - std::set> d_additionals4{}; - std::set> d_additionals6{}; -}; - -class TCAction : public DNSAction -{ -public: - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - return Action::Truncate; - } - [[nodiscard]] std::string toString() const override - { - return "tc=1 answer"; - } -}; - -class TCResponseAction : public DNSResponseAction -{ -public: - DNSResponseAction::Action operator()(DNSResponse* dnsResponse, std::string* ruleresult) const override - { - return Action::Truncate; - } - [[nodiscard]] std::string toString() const override - { - return "tc=1 answer"; - } -}; - -class LuaAction : public DNSAction -{ -public: - using func_t = std::function>(DNSQuestion* dnsquestion)>; - LuaAction(LuaAction::func_t func) : - d_func(std::move(func)) - {} - - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - try { - DNSAction::Action result{}; - { - auto lock = g_lua.lock(); - auto ret = d_func(dnsquestion); - if (ruleresult != nullptr) { - if (boost::optional rule = std::get<1>(ret)) { - *ruleresult = *rule; - } - else { - // default to empty string - ruleresult->clear(); - } - } - result = static_cast(std::get<0>(ret)); - } - dnsdist::handleQueuedAsynchronousEvents(); - return result; - } - catch (const std::exception& e) { - warnlog("LuaAction failed inside Lua, returning ServFail: %s", e.what()); - } - catch (...) { - warnlog("LuaAction failed inside Lua, returning ServFail: [unknown exception]"); - } - return DNSAction::Action::ServFail; - } - - [[nodiscard]] std::string toString() const override - { - return "Lua script"; - } - -private: - func_t d_func; -}; - -class LuaResponseAction : public DNSResponseAction -{ -public: - using func_t = std::function>(DNSResponse* response)>; - LuaResponseAction(LuaResponseAction::func_t func) : - d_func(std::move(func)) - {} - DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override - { - try { - DNSResponseAction::Action result{}; - { - auto lock = g_lua.lock(); - auto ret = d_func(response); - if (ruleresult != nullptr) { - if (boost::optional rule = std::get<1>(ret)) { - *ruleresult = *rule; - } - else { - // default to empty string - ruleresult->clear(); - } - } - result = static_cast(std::get<0>(ret)); - } - dnsdist::handleQueuedAsynchronousEvents(); - return result; - } - catch (const std::exception& e) { - warnlog("LuaResponseAction failed inside Lua, returning ServFail: %s", e.what()); - } - catch (...) { - warnlog("LuaResponseAction failed inside Lua, returning ServFail: [unknown exception]"); - } - return DNSResponseAction::Action::ServFail; - } - - [[nodiscard]] std::string toString() const override - { - return "Lua response script"; - } - -private: - func_t d_func; -}; - -class LuaFFIAction : public DNSAction -{ -public: - using func_t = std::function; - - LuaFFIAction(LuaFFIAction::func_t func) : - d_func(std::move(func)) - { - } - - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - dnsdist_ffi_dnsquestion_t dqffi(dnsquestion); - try { - DNSAction::Action result{}; - { - auto lock = g_lua.lock(); - auto ret = d_func(&dqffi); - if (ruleresult != nullptr) { - if (dqffi.result) { - *ruleresult = *dqffi.result; - } - else { - // default to empty string - ruleresult->clear(); - } - } - result = static_cast(ret); - } - dnsdist::handleQueuedAsynchronousEvents(); - return result; - } - catch (const std::exception& e) { - warnlog("LuaFFIAction failed inside Lua, returning ServFail: %s", e.what()); - } - catch (...) { - warnlog("LuaFFIAction failed inside Lua, returning ServFail: [unknown exception]"); - } - return DNSAction::Action::ServFail; - } - - [[nodiscard]] std::string toString() const override - { - return "Lua FFI script"; - } - -private: - func_t d_func; -}; - -class LuaFFIPerThreadAction : public DNSAction -{ -public: - using func_t = std::function; - - LuaFFIPerThreadAction(std::string code) : - d_functionCode(std::move(code)), d_functionID(s_functionsCounter++) - { - } - - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - try { - auto& state = t_perThreadStates[d_functionID]; - if (!state.d_initialized) { - setupLuaFFIPerThreadContext(state.d_luaContext); - /* mark the state as initialized first so if there is a syntax error - we only try to execute the code once */ - state.d_initialized = true; - state.d_func = state.d_luaContext.executeCode(d_functionCode); - } - - if (!state.d_func) { - /* the function was not properly initialized */ - return DNSAction::Action::None; - } - - dnsdist_ffi_dnsquestion_t dqffi(dnsquestion); - auto ret = state.d_func(&dqffi); - if (ruleresult != nullptr) { - if (dqffi.result) { - *ruleresult = *dqffi.result; - } - else { - // default to empty string - ruleresult->clear(); - } - } - dnsdist::handleQueuedAsynchronousEvents(); - return static_cast(ret); - } - catch (const std::exception& e) { - warnlog("LuaFFIPerThreadAction failed inside Lua, returning ServFail: %s", e.what()); - } - catch (...) { - warnlog("LuaFFIPerthreadAction failed inside Lua, returning ServFail: [unknown exception]"); - } - return DNSAction::Action::ServFail; - } - - [[nodiscard]] std::string toString() const override - { - return "Lua FFI per-thread script"; - } - -private: - struct PerThreadState - { - LuaContext d_luaContext; - func_t d_func; - bool d_initialized{false}; - }; - static std::atomic s_functionsCounter; - static thread_local std::map t_perThreadStates; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - const std::string d_functionCode; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - const uint64_t d_functionID; -}; - -std::atomic LuaFFIPerThreadAction::s_functionsCounter = 0; -thread_local std::map LuaFFIPerThreadAction::t_perThreadStates; - -class LuaFFIResponseAction : public DNSResponseAction -{ -public: - using func_t = std::function; - - LuaFFIResponseAction(LuaFFIResponseAction::func_t func) : - d_func(std::move(func)) - { - } - - DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override - { - dnsdist_ffi_dnsresponse_t ffiResponse(response); - try { - DNSResponseAction::Action result{}; - { - auto lock = g_lua.lock(); - auto ret = d_func(&ffiResponse); - if (ruleresult != nullptr) { - if (ffiResponse.result) { - *ruleresult = *ffiResponse.result; - } - else { - // default to empty string - ruleresult->clear(); - } - } - result = static_cast(ret); - } - dnsdist::handleQueuedAsynchronousEvents(); - return result; - } - catch (const std::exception& e) { - warnlog("LuaFFIResponseAction failed inside Lua, returning ServFail: %s", e.what()); - } - catch (...) { - warnlog("LuaFFIResponseAction failed inside Lua, returning ServFail: [unknown exception]"); - } - return DNSResponseAction::Action::ServFail; - } - - [[nodiscard]] std::string toString() const override - { - return "Lua FFI script"; - } - -private: - func_t d_func; -}; - -class LuaFFIPerThreadResponseAction : public DNSResponseAction -{ -public: - using func_t = std::function; - - LuaFFIPerThreadResponseAction(std::string code) : - d_functionCode(std::move(code)), d_functionID(s_functionsCounter++) - { - } - - DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override - { - try { - auto& state = t_perThreadStates[d_functionID]; - if (!state.d_initialized) { - setupLuaFFIPerThreadContext(state.d_luaContext); - /* mark the state as initialized first so if there is a syntax error - we only try to execute the code once */ - state.d_initialized = true; - state.d_func = state.d_luaContext.executeCode(d_functionCode); - } - - if (!state.d_func) { - /* the function was not properly initialized */ - return DNSResponseAction::Action::None; - } - - dnsdist_ffi_dnsresponse_t ffiResponse(response); - auto ret = state.d_func(&ffiResponse); - if (ruleresult != nullptr) { - if (ffiResponse.result) { - *ruleresult = *ffiResponse.result; - } - else { - // default to empty string - ruleresult->clear(); - } - } - dnsdist::handleQueuedAsynchronousEvents(); - return static_cast(ret); - } - catch (const std::exception& e) { - warnlog("LuaFFIPerThreadResponseAction failed inside Lua, returning ServFail: %s", e.what()); - } - catch (...) { - warnlog("LuaFFIPerthreadResponseAction failed inside Lua, returning ServFail: [unknown exception]"); - } - return DNSResponseAction::Action::ServFail; - } - - [[nodiscard]] std::string toString() const override - { - return "Lua FFI per-thread script"; - } - -private: - struct PerThreadState - { - LuaContext d_luaContext; - func_t d_func; - bool d_initialized{false}; - }; - - static std::atomic s_functionsCounter; - static thread_local std::map t_perThreadStates; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - const std::string d_functionCode; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - const uint64_t d_functionID; -}; - -std::atomic LuaFFIPerThreadResponseAction::s_functionsCounter = 0; -thread_local std::map LuaFFIPerThreadResponseAction::t_perThreadStates; - -class SpoofAction : public DNSAction -{ -public: - SpoofAction(const vector& addrs, const dnsdist::ResponseConfig& responseConfig) : - d_responseConfig(responseConfig), d_addrs(addrs) - { - for (const auto& addr : d_addrs) { - if (addr.isIPv4()) { - d_types.insert(QType::A); - } - else if (addr.isIPv6()) { - d_types.insert(QType::AAAA); - } - } - - if (!d_addrs.empty()) { - d_types.insert(QType::ANY); - } - } - - SpoofAction(const DNSName& cname, const dnsdist::ResponseConfig& responseConfig) : - d_responseConfig(responseConfig), d_cname(cname) - { - } - - SpoofAction(const PacketBuffer& rawresponse) : - d_raw(rawresponse) - { - } - - SpoofAction(const vector& raws, std::optional typeForAny, const dnsdist::ResponseConfig& responseConfig) : - d_responseConfig(responseConfig), d_rawResponses(raws), d_rawTypeForAny(typeForAny) - { - } - - DNSAction::Action operator()(DNSQuestion* dnsquestion, string* ruleresult) const override; - - string toString() const override - { - string ret = "spoof in "; - if (!d_cname.empty()) { - ret += d_cname.toString() + " "; - } - if (d_rawResponses.size() > 0) { - ret += "raw bytes "; - } - else { - for (const auto& a : d_addrs) - ret += a.toString() + " "; - } - return ret; - } - -private: - dnsdist::ResponseConfig d_responseConfig; - std::vector d_addrs; - std::unordered_set d_types; - std::vector d_rawResponses; - PacketBuffer d_raw; - DNSName d_cname; - std::optional d_rawTypeForAny{}; -}; - -DNSAction::Action SpoofAction::operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const -{ - uint16_t qtype = dnsquestion->ids.qtype; - // do we even have a response? - if (d_cname.empty() && d_rawResponses.empty() && - // make sure pre-forged response is greater than sizeof(dnsheader) - (d_raw.size() < sizeof(dnsheader)) && d_types.count(qtype) == 0) { - return Action::None; - } - - if (d_raw.size() >= sizeof(dnsheader)) { - dnsdist::self_answers::generateAnswerFromRawPacket(*dnsquestion, d_raw); - return Action::HeaderModify; - } - - if (!d_cname.empty()) { - if (dnsdist::self_answers::generateAnswerFromCNAME(*dnsquestion, d_cname, d_responseConfig)) { - return Action::HeaderModify; - } - } - else if (!d_rawResponses.empty()) { - if (dnsdist::self_answers::generateAnswerFromRDataEntries(*dnsquestion, d_rawResponses, d_rawTypeForAny, d_responseConfig)) { - return Action::HeaderModify; - } - } - else { - if (dnsdist::self_answers::generateAnswerFromIPAddresses(*dnsquestion, d_addrs, d_responseConfig)) { - return Action::HeaderModify; - } - } - - return Action::None; -} - -class SetMacAddrAction : public DNSAction -{ -public: - // this action does not stop the processing - SetMacAddrAction(uint16_t code) : - d_code(code) - { - } - - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - dnsdist::MacAddress mac{}; - int res = dnsdist::MacAddressesCache::get(dnsquestion->ids.origRemote, mac.data(), mac.size()); - if (res != 0) { - return Action::None; - } - - std::string optRData; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - generateEDNSOption(d_code, reinterpret_cast(mac.data()), optRData); - - if (dnsquestion->getHeader()->arcount > 0) { - bool ednsAdded = false; - bool optionAdded = false; - PacketBuffer newContent; - newContent.reserve(dnsquestion->getData().size()); - - if (!slowRewriteEDNSOptionInQueryWithRecords(dnsquestion->getData(), newContent, ednsAdded, d_code, optionAdded, true, optRData)) { - return Action::None; - } - - if (newContent.size() > dnsquestion->getMaximumSize()) { - return Action::None; - } - - dnsquestion->getMutableData() = std::move(newContent); - if (!dnsquestion->ids.ednsAdded && ednsAdded) { - dnsquestion->ids.ednsAdded = true; - } - - return Action::None; - } - - auto& data = dnsquestion->getMutableData(); - if (generateOptRR(optRData, data, dnsquestion->getMaximumSize(), dnsdist::configuration::s_EdnsUDPPayloadSize, 0, false)) { - dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [](dnsheader& header) { - header.arcount = htons(1); - return true; - }); - // make sure that any EDNS sent by the backend is removed before forwarding the response to the client - dnsquestion->ids.ednsAdded = true; - } - - return Action::None; - } - [[nodiscard]] std::string toString() const override - { - return "add EDNS MAC (code=" + std::to_string(d_code) + ")"; - } - -private: - uint16_t d_code{3}; -}; - -class SetEDNSOptionAction : public DNSAction -{ -public: - // this action does not stop the processing - SetEDNSOptionAction(uint16_t code, std::string data) : - d_code(code), d_data(std::move(data)) - { - } - - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - setEDNSOption(*dnsquestion, d_code, d_data); - return Action::None; - } - - [[nodiscard]] std::string toString() const override - { - return "add EDNS Option (code=" + std::to_string(d_code) + ")"; - } - -private: - uint16_t d_code; - std::string d_data; -}; - -class SetNoRecurseAction : public DNSAction -{ -public: - // this action does not stop the processing - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [](dnsheader& header) { - header.rd = false; - return true; - }); - return Action::None; - } - [[nodiscard]] std::string toString() const override - { - return "set rd=0"; - } -}; - -class LogAction : public DNSAction, public boost::noncopyable -{ -public: - // this action does not stop the processing - LogAction() = default; - - LogAction(const std::string& str, bool binary = true, bool append = false, bool buffered = true, bool verboseOnly = true, bool includeTimestamp = false) : - d_fname(str), d_binary(binary), d_verboseOnly(verboseOnly), d_includeTimestamp(includeTimestamp), d_append(append), d_buffered(buffered) - { - if (str.empty()) { - return; - } - - if (!reopenLogFile()) { - throw std::runtime_error("Unable to open file '" + str + "' for logging: " + stringerror()); - } - } - - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - auto filepointer = std::atomic_load_explicit(&d_fp, std::memory_order_acquire); - if (!filepointer) { - if (!d_verboseOnly || dnsdist::configuration::getCurrentRuntimeConfiguration().d_verbose) { - if (d_includeTimestamp) { - infolog("[%u.%u] Packet from %s for %s %s with id %d", static_cast(dnsquestion->getQueryRealTime().tv_sec), static_cast(dnsquestion->getQueryRealTime().tv_nsec), dnsquestion->ids.origRemote.toStringWithPort(), dnsquestion->ids.qname.toString(), QType(dnsquestion->ids.qtype).toString(), dnsquestion->getHeader()->id); - } - else { - infolog("Packet from %s for %s %s with id %d", dnsquestion->ids.origRemote.toStringWithPort(), dnsquestion->ids.qname.toString(), QType(dnsquestion->ids.qtype).toString(), dnsquestion->getHeader()->id); - } - } - } - else { - if (d_binary) { - const auto& out = dnsquestion->ids.qname.getStorage(); - if (d_includeTimestamp) { - auto tv_sec = static_cast(dnsquestion->getQueryRealTime().tv_sec); - auto tv_nsec = static_cast(dnsquestion->getQueryRealTime().tv_nsec); - fwrite(&tv_sec, sizeof(tv_sec), 1, filepointer.get()); - fwrite(&tv_nsec, sizeof(tv_nsec), 1, filepointer.get()); - } - uint16_t queryId = dnsquestion->getHeader()->id; - fwrite(&queryId, sizeof(queryId), 1, filepointer.get()); - fwrite(out.c_str(), 1, out.size(), filepointer.get()); - fwrite(&dnsquestion->ids.qtype, sizeof(dnsquestion->ids.qtype), 1, filepointer.get()); - fwrite(&dnsquestion->ids.origRemote.sin4.sin_family, sizeof(dnsquestion->ids.origRemote.sin4.sin_family), 1, filepointer.get()); - if (dnsquestion->ids.origRemote.sin4.sin_family == AF_INET) { - fwrite(&dnsquestion->ids.origRemote.sin4.sin_addr.s_addr, sizeof(dnsquestion->ids.origRemote.sin4.sin_addr.s_addr), 1, filepointer.get()); - } - else if (dnsquestion->ids.origRemote.sin4.sin_family == AF_INET6) { - fwrite(&dnsquestion->ids.origRemote.sin6.sin6_addr.s6_addr, sizeof(dnsquestion->ids.origRemote.sin6.sin6_addr.s6_addr), 1, filepointer.get()); - } - fwrite(&dnsquestion->ids.origRemote.sin4.sin_port, sizeof(dnsquestion->ids.origRemote.sin4.sin_port), 1, filepointer.get()); - } - else { - if (d_includeTimestamp) { - fprintf(filepointer.get(), "[%llu.%lu] Packet from %s for %s %s with id %u\n", static_cast(dnsquestion->getQueryRealTime().tv_sec), static_cast(dnsquestion->getQueryRealTime().tv_nsec), dnsquestion->ids.origRemote.toStringWithPort().c_str(), dnsquestion->ids.qname.toString().c_str(), QType(dnsquestion->ids.qtype).toString().c_str(), dnsquestion->getHeader()->id); - } - else { - fprintf(filepointer.get(), "Packet from %s for %s %s with id %u\n", dnsquestion->ids.origRemote.toStringWithPort().c_str(), dnsquestion->ids.qname.toString().c_str(), QType(dnsquestion->ids.qtype).toString().c_str(), dnsquestion->getHeader()->id); - } - } - } - return Action::None; - } - - [[nodiscard]] std::string toString() const override - { - if (!d_fname.empty()) { - return "log to " + d_fname; - } - return "log"; - } - - void reload() override - { - if (!reopenLogFile()) { - warnlog("Unable to open file '%s' for logging: %s", d_fname, stringerror()); - } - } - -private: - bool reopenLogFile() - { - // we are using a naked pointer here because we don't want fclose to be called - // with a nullptr, which would happen if we constructor a shared_ptr with fclose - // as a custom deleter and nullptr as a FILE* - // NOLINTNEXTLINE(cppcoreguidelines-owning-memory) - auto* nfp = fopen(d_fname.c_str(), d_append ? "a+" : "w"); - if (nfp == nullptr) { - /* don't fall on our sword when reopening */ - return false; - } - - auto filepointer = std::shared_ptr(nfp, fclose); - nfp = nullptr; - - if (!d_buffered) { - setbuf(filepointer.get(), nullptr); - } - - std::atomic_store_explicit(&d_fp, std::move(filepointer), std::memory_order_release); - return true; - } - - std::string d_fname; - std::shared_ptr d_fp{nullptr}; - bool d_binary{true}; - bool d_verboseOnly{true}; - bool d_includeTimestamp{false}; - bool d_append{false}; - bool d_buffered{true}; -}; - -class LogResponseAction : public DNSResponseAction, public boost::noncopyable -{ -public: - LogResponseAction() = default; - - LogResponseAction(const std::string& str, bool append = false, bool buffered = true, bool verboseOnly = true, bool includeTimestamp = false) : - d_fname(str), d_verboseOnly(verboseOnly), d_includeTimestamp(includeTimestamp), d_append(append), d_buffered(buffered) - { - if (str.empty()) { - return; - } - - if (!reopenLogFile()) { - throw std::runtime_error("Unable to open file '" + str + "' for logging: " + stringerror()); - } - } - - DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override - { - auto filepointer = std::atomic_load_explicit(&d_fp, std::memory_order_acquire); - if (!filepointer) { - if (!d_verboseOnly || dnsdist::configuration::getCurrentRuntimeConfiguration().d_verbose) { - if (d_includeTimestamp) { - infolog("[%u.%u] Answer to %s for %s %s (%s) with id %u", static_cast(response->getQueryRealTime().tv_sec), static_cast(response->getQueryRealTime().tv_nsec), response->ids.origRemote.toStringWithPort(), response->ids.qname.toString(), QType(response->ids.qtype).toString(), RCode::to_s(response->getHeader()->rcode), response->getHeader()->id); - } - else { - infolog("Answer to %s for %s %s (%s) with id %u", response->ids.origRemote.toStringWithPort(), response->ids.qname.toString(), QType(response->ids.qtype).toString(), RCode::to_s(response->getHeader()->rcode), response->getHeader()->id); - } - } - } - else { - if (d_includeTimestamp) { - fprintf(filepointer.get(), "[%llu.%lu] Answer to %s for %s %s (%s) with id %u\n", static_cast(response->getQueryRealTime().tv_sec), static_cast(response->getQueryRealTime().tv_nsec), response->ids.origRemote.toStringWithPort().c_str(), response->ids.qname.toString().c_str(), QType(response->ids.qtype).toString().c_str(), RCode::to_s(response->getHeader()->rcode).c_str(), response->getHeader()->id); - } - else { - fprintf(filepointer.get(), "Answer to %s for %s %s (%s) with id %u\n", response->ids.origRemote.toStringWithPort().c_str(), response->ids.qname.toString().c_str(), QType(response->ids.qtype).toString().c_str(), RCode::to_s(response->getHeader()->rcode).c_str(), response->getHeader()->id); - } - } - return Action::None; - } - - [[nodiscard]] std::string toString() const override - { - if (!d_fname.empty()) { - return "log to " + d_fname; - } - return "log"; - } - - void reload() override - { - if (!reopenLogFile()) { - warnlog("Unable to open file '%s' for logging: %s", d_fname, stringerror()); - } - } - -private: - bool reopenLogFile() - { - // we are using a naked pointer here because we don't want fclose to be called - // with a nullptr, which would happen if we constructor a shared_ptr with fclose - // as a custom deleter and nullptr as a FILE* - // NOLINTNEXTLINE(cppcoreguidelines-owning-memory) - auto* nfp = fopen(d_fname.c_str(), d_append ? "a+" : "w"); - if (nfp == nullptr) { - /* don't fall on our sword when reopening */ - return false; - } - - auto filepointer = std::shared_ptr(nfp, fclose); - nfp = nullptr; - - if (!d_buffered) { - setbuf(filepointer.get(), nullptr); - } - - std::atomic_store_explicit(&d_fp, std::move(filepointer), std::memory_order_release); - return true; - } - - std::string d_fname; - std::shared_ptr d_fp{nullptr}; - bool d_verboseOnly{true}; - bool d_includeTimestamp{false}; - bool d_append{false}; - bool d_buffered{true}; -}; - -class SetDisableValidationAction : public DNSAction -{ -public: - // this action does not stop the processing - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [](dnsheader& header) { - header.cd = true; - return true; - }); - return Action::None; - } - [[nodiscard]] std::string toString() const override - { - return "set cd=1"; - } -}; - -class SetSkipCacheAction : public DNSAction -{ -public: - // this action does not stop the processing - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - dnsquestion->ids.skipCache = true; - return Action::None; - } - [[nodiscard]] std::string toString() const override - { - return "skip cache"; - } -}; - -class SetSkipCacheResponseAction : public DNSResponseAction -{ -public: - DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override - { - response->ids.skipCache = true; - return Action::None; - } - [[nodiscard]] std::string toString() const override - { - return "skip cache"; - } -}; - -class SetTempFailureCacheTTLAction : public DNSAction -{ -public: - // this action does not stop the processing - SetTempFailureCacheTTLAction(uint32_t ttl) : - d_ttl(ttl) - { - } - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - dnsquestion->ids.tempFailureTTL = d_ttl; - return Action::None; - } - [[nodiscard]] std::string toString() const override - { - return "set tempfailure cache ttl to " + std::to_string(d_ttl); - } - -private: - uint32_t d_ttl; -}; - -class SetECSPrefixLengthAction : public DNSAction -{ -public: - // this action does not stop the processing - SetECSPrefixLengthAction(uint16_t v4Length, uint16_t v6Length) : - d_v4PrefixLength(v4Length), d_v6PrefixLength(v6Length) - { - } - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - dnsquestion->ecsPrefixLength = dnsquestion->ids.origRemote.sin4.sin_family == AF_INET ? d_v4PrefixLength : d_v6PrefixLength; - return Action::None; - } - [[nodiscard]] std::string toString() const override - { - return "set ECS prefix length to " + std::to_string(d_v4PrefixLength) + "/" + std::to_string(d_v6PrefixLength); - } - -private: - uint16_t d_v4PrefixLength; - uint16_t d_v6PrefixLength; -}; - -class SetECSOverrideAction : public DNSAction -{ -public: - // this action does not stop the processing - SetECSOverrideAction(bool ecsOverride) : - d_ecsOverride(ecsOverride) - { - } - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - dnsquestion->ecsOverride = d_ecsOverride; - return Action::None; - } - [[nodiscard]] std::string toString() const override - { - return "set ECS override to " + std::to_string(static_cast(d_ecsOverride)); - } - -private: - bool d_ecsOverride; -}; - -class SetDisableECSAction : public DNSAction -{ -public: - // this action does not stop the processing - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - dnsquestion->useECS = false; - return Action::None; - } - [[nodiscard]] std::string toString() const override - { - return "disable ECS"; - } -}; - -class SetECSAction : public DNSAction -{ -public: - // this action does not stop the processing - SetECSAction(const Netmask& v4Netmask) : - d_v4(v4Netmask), d_hasV6(false) - { - } - - SetECSAction(const Netmask& v4Netmask, const Netmask& v6Netmask) : - d_v4(v4Netmask), d_v6(v6Netmask), d_hasV6(true) - { - } - - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - if (d_hasV6) { - dnsquestion->ecs = std::make_unique(dnsquestion->ids.origRemote.isIPv4() ? d_v4 : d_v6); - } - else { - dnsquestion->ecs = std::make_unique(d_v4); - } - - return Action::None; - } - - [[nodiscard]] std::string toString() const override - { - std::string result = "set ECS to " + d_v4.toString(); - if (d_hasV6) { - result += " / " + d_v6.toString(); - } - return result; - } - -private: - Netmask d_v4; - Netmask d_v6; - bool d_hasV6; -}; - -#ifndef DISABLE_PROTOBUF -static DnstapMessage::ProtocolType ProtocolToDNSTap(dnsdist::Protocol protocol) -{ - if (protocol == dnsdist::Protocol::DoUDP) { - return DnstapMessage::ProtocolType::DoUDP; - } - if (protocol == dnsdist::Protocol::DoTCP) { - return DnstapMessage::ProtocolType::DoTCP; - } - if (protocol == dnsdist::Protocol::DoT) { - return DnstapMessage::ProtocolType::DoT; - } - if (protocol == dnsdist::Protocol::DoH || protocol == dnsdist::Protocol::DoH3) { - return DnstapMessage::ProtocolType::DoH; - } - if (protocol == dnsdist::Protocol::DNSCryptUDP) { - return DnstapMessage::ProtocolType::DNSCryptUDP; - } - if (protocol == dnsdist::Protocol::DNSCryptTCP) { - return DnstapMessage::ProtocolType::DNSCryptTCP; - } - if (protocol == dnsdist::Protocol::DoQ) { - return DnstapMessage::ProtocolType::DoQ; - } - throw std::runtime_error("Unhandled protocol for dnstap: " + protocol.toPrettyString()); -} - -static void remoteLoggerQueueData(RemoteLoggerInterface& remoteLogger, const std::string& data) -{ - auto ret = remoteLogger.queueData(data); - - switch (ret) { - case RemoteLoggerInterface::Result::Queued: - break; - case RemoteLoggerInterface::Result::PipeFull: { - vinfolog("%s: %s", remoteLogger.name(), RemoteLoggerInterface::toErrorString(ret)); - break; - } - case RemoteLoggerInterface::Result::TooLarge: { - warnlog("%s: %s", remoteLogger.name(), RemoteLoggerInterface::toErrorString(ret)); - break; - } - case RemoteLoggerInterface::Result::OtherError: - warnlog("%s: %s", remoteLogger.name(), RemoteLoggerInterface::toErrorString(ret)); - } -} - -class DnstapLogAction : public DNSAction, public boost::noncopyable -{ -public: - // this action does not stop the processing - DnstapLogAction(std::string identity, std::shared_ptr& logger, boost::optional> alterFunc) : - d_identity(std::move(identity)), d_logger(logger), d_alterFunc(std::move(alterFunc)) - { - } - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - static thread_local std::string data; - data.clear(); - - DnstapMessage::ProtocolType protocol = ProtocolToDNSTap(dnsquestion->getProtocol()); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - DnstapMessage message(std::move(data), !dnsquestion->getHeader()->qr ? DnstapMessage::MessageType::client_query : DnstapMessage::MessageType::client_response, d_identity, &dnsquestion->ids.origRemote, &dnsquestion->ids.origDest, protocol, reinterpret_cast(dnsquestion->getData().data()), dnsquestion->getData().size(), &dnsquestion->getQueryRealTime(), nullptr); - { - if (d_alterFunc) { - auto lock = g_lua.lock(); - (*d_alterFunc)(dnsquestion, &message); - } - } - - data = message.getBuffer(); - remoteLoggerQueueData(*d_logger, data); - - return Action::None; - } - [[nodiscard]] std::string toString() const override - { - return "remote log as dnstap to " + (d_logger ? d_logger->toString() : ""); - } - -private: - std::string d_identity; - std::shared_ptr d_logger; - boost::optional> d_alterFunc; -}; - -namespace -{ -void addMetaDataToProtobuf(DNSDistProtoBufMessage& message, const DNSQuestion& dnsquestion, const std::vector>& metas) -{ - for (const auto& [name, meta] : metas) { - message.addMeta(name, meta.getValues(dnsquestion), {}); - } -} - -void addTagsToProtobuf(DNSDistProtoBufMessage& message, const DNSQuestion& dnsquestion, const std::unordered_set& allowed) -{ - if (!dnsquestion.ids.qTag) { - return; - } - - for (const auto& [key, value] : *dnsquestion.ids.qTag) { - if (!allowed.empty() && allowed.count(key) == 0) { - continue; - } - - if (value.empty()) { - message.addTag(key); - } - else { - auto tag = key; - tag.append(":"); - tag.append(value); - message.addTag(tag); - } - } -} - -void addExtendedDNSErrorToProtobuf(DNSDistProtoBufMessage& message, const DNSResponse& response, const std::string& metaKey) -{ - auto [infoCode, extraText] = dnsdist::edns::getExtendedDNSError(response.getData()); - if (!infoCode) { - return; - } - - if (extraText) { - message.addMeta(metaKey, {*extraText}, {*infoCode}); - } - else { - message.addMeta(metaKey, {}, {*infoCode}); - } -} -} - -struct RemoteLogActionConfiguration -{ - std::vector> metas; - std::optional> tagsToExport{std::nullopt}; - boost::optional> alterQueryFunc{boost::none}; - boost::optional> alterResponseFunc{boost::none}; - std::shared_ptr logger; - std::string serverID; - std::string ipEncryptKey; - std::optional exportExtendedErrorsToMeta{std::nullopt}; - bool includeCNAME{false}; -}; - -class RemoteLogAction : public DNSAction, public boost::noncopyable -{ -public: - // this action does not stop the processing - RemoteLogAction(RemoteLogActionConfiguration& config) : - d_tagsToExport(std::move(config.tagsToExport)), d_metas(std::move(config.metas)), d_logger(config.logger), d_alterFunc(std::move(config.alterQueryFunc)), d_serverID(config.serverID), d_ipEncryptKey(config.ipEncryptKey) - { - } - - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - if (!dnsquestion->ids.d_protoBufData) { - dnsquestion->ids.d_protoBufData = std::make_unique(); - } - if (!dnsquestion->ids.d_protoBufData->uniqueId) { - dnsquestion->ids.d_protoBufData->uniqueId = getUniqueID(); - } - - DNSDistProtoBufMessage message(*dnsquestion); - if (!d_serverID.empty()) { - message.setServerIdentity(d_serverID); - } - -#ifdef HAVE_IPCIPHER - if (!d_ipEncryptKey.empty()) { - message.setRequestor(encryptCA(dnsquestion->ids.origRemote, d_ipEncryptKey)); - } -#endif /* HAVE_IPCIPHER */ - - if (d_tagsToExport) { - addTagsToProtobuf(message, *dnsquestion, *d_tagsToExport); - } - - addMetaDataToProtobuf(message, *dnsquestion, d_metas); - - if (d_alterFunc) { - auto lock = g_lua.lock(); - (*d_alterFunc)(dnsquestion, &message); - } - - static thread_local std::string data; - data.clear(); - message.serialize(data); - remoteLoggerQueueData(*d_logger, data); - - return Action::None; - } - [[nodiscard]] std::string toString() const override - { - return "remote log to " + (d_logger ? d_logger->toString() : ""); - } - -private: - std::optional> d_tagsToExport; - std::vector> d_metas; - std::shared_ptr d_logger; - boost::optional> d_alterFunc; - std::string d_serverID; - std::string d_ipEncryptKey; -}; - -#endif /* DISABLE_PROTOBUF */ - -class SNMPTrapAction : public DNSAction -{ -public: - // this action does not stop the processing - SNMPTrapAction(std::string reason) : - d_reason(std::move(reason)) - { - } - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - if (g_snmpAgent != nullptr && dnsdist::configuration::getImmutableConfiguration().d_snmpTrapsEnabled) { - g_snmpAgent->sendDNSTrap(*dnsquestion, d_reason); - } - - return Action::None; - } - [[nodiscard]] std::string toString() const override - { - return "send SNMP trap"; - } - -private: - std::string d_reason; -}; - -class SetTagAction : public DNSAction -{ -public: - // this action does not stop the processing - SetTagAction(std::string tag, std::string value) : - d_tag(std::move(tag)), d_value(std::move(value)) - { - } - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - dnsquestion->setTag(d_tag, d_value); - - return Action::None; - } - [[nodiscard]] std::string toString() const override - { - return "set tag '" + d_tag + "' to value '" + d_value + "'"; - } - -private: - std::string d_tag; - std::string d_value; -}; - -#ifndef DISABLE_PROTOBUF -class DnstapLogResponseAction : public DNSResponseAction, public boost::noncopyable -{ -public: - // this action does not stop the processing - DnstapLogResponseAction(std::string identity, std::shared_ptr& logger, boost::optional> alterFunc) : - d_identity(std::move(identity)), d_logger(logger), d_alterFunc(std::move(alterFunc)) - { - } - DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override - { - static thread_local std::string data; - struct timespec now = {}; - gettime(&now, true); - data.clear(); - - DnstapMessage::ProtocolType protocol = ProtocolToDNSTap(response->getProtocol()); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - DnstapMessage message(std::move(data), DnstapMessage::MessageType::client_response, d_identity, &response->ids.origRemote, &response->ids.origDest, protocol, reinterpret_cast(response->getData().data()), response->getData().size(), &response->getQueryRealTime(), &now); - { - if (d_alterFunc) { - auto lock = g_lua.lock(); - (*d_alterFunc)(response, &message); - } - } - - data = message.getBuffer(); - remoteLoggerQueueData(*d_logger, data); - - return Action::None; - } - [[nodiscard]] std::string toString() const override - { - return "log response as dnstap to " + (d_logger ? d_logger->toString() : ""); - } - -private: - std::string d_identity; - std::shared_ptr d_logger; - boost::optional> d_alterFunc; -}; - -class RemoteLogResponseAction : public DNSResponseAction, public boost::noncopyable -{ -public: - // this action does not stop the processing - RemoteLogResponseAction(RemoteLogActionConfiguration& config) : - d_tagsToExport(std::move(config.tagsToExport)), d_metas(std::move(config.metas)), d_logger(config.logger), d_alterFunc(std::move(config.alterResponseFunc)), d_serverID(config.serverID), d_ipEncryptKey(config.ipEncryptKey), d_exportExtendedErrorsToMeta(std::move(config.exportExtendedErrorsToMeta)), d_includeCNAME(config.includeCNAME) - { - } - DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override - { - if (!response->ids.d_protoBufData) { - response->ids.d_protoBufData = std::make_unique(); - } - if (!response->ids.d_protoBufData->uniqueId) { - response->ids.d_protoBufData->uniqueId = getUniqueID(); - } - - DNSDistProtoBufMessage message(*response, d_includeCNAME); - if (!d_serverID.empty()) { - message.setServerIdentity(d_serverID); - } - -#ifdef HAVE_IPCIPHER - if (!d_ipEncryptKey.empty()) { - message.setRequestor(encryptCA(response->ids.origRemote, d_ipEncryptKey)); - } -#endif /* HAVE_IPCIPHER */ - - if (d_tagsToExport) { - addTagsToProtobuf(message, *response, *d_tagsToExport); - } - - addMetaDataToProtobuf(message, *response, d_metas); - - if (d_exportExtendedErrorsToMeta) { - addExtendedDNSErrorToProtobuf(message, *response, *d_exportExtendedErrorsToMeta); - } - - if (d_alterFunc) { - auto lock = g_lua.lock(); - (*d_alterFunc)(response, &message); - } - - static thread_local std::string data; - data.clear(); - message.serialize(data); - d_logger->queueData(data); - - return Action::None; - } - [[nodiscard]] std::string toString() const override - { - return "remote log response to " + (d_logger ? d_logger->toString() : ""); - } - -private: - std::optional> d_tagsToExport; - std::vector> d_metas; - std::shared_ptr d_logger; - boost::optional> d_alterFunc; - std::string d_serverID; - std::string d_ipEncryptKey; - std::optional d_exportExtendedErrorsToMeta{std::nullopt}; - bool d_includeCNAME; -}; - -#endif /* DISABLE_PROTOBUF */ - -class DropResponseAction : public DNSResponseAction -{ -public: - DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override - { - return Action::Drop; - } - [[nodiscard]] std::string toString() const override - { - return "drop"; - } -}; - -class AllowResponseAction : public DNSResponseAction -{ -public: - DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override - { - return Action::Allow; - } - [[nodiscard]] std::string toString() const override - { - return "allow"; - } -}; - -class DelayResponseAction : public DNSResponseAction -{ -public: - DelayResponseAction(int msec) : - d_msec(msec) - { - } - DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override - { - *ruleresult = std::to_string(d_msec); - return Action::Delay; - } - [[nodiscard]] std::string toString() const override - { - return "delay by " + std::to_string(d_msec) + " ms"; - } - -private: - int d_msec; -}; - -#ifdef HAVE_NET_SNMP -class SNMPTrapResponseAction : public DNSResponseAction -{ -public: - // this action does not stop the processing - SNMPTrapResponseAction(std::string reason) : - d_reason(std::move(reason)) - { - } - DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override - { - if (g_snmpAgent != nullptr && dnsdist::configuration::getImmutableConfiguration().d_snmpTrapsEnabled) { - g_snmpAgent->sendDNSTrap(*response, d_reason); - } - - return Action::None; - } - [[nodiscard]] std::string toString() const override - { - return "send SNMP trap"; - } - -private: - std::string d_reason; -}; -#endif /* HAVE_NET_SNMP */ - -class SetTagResponseAction : public DNSResponseAction -{ -public: - // this action does not stop the processing - SetTagResponseAction(std::string tag, std::string value) : - d_tag(std::move(tag)), d_value(std::move(value)) - { - } - DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override - { - response->setTag(d_tag, d_value); - - return Action::None; - } - [[nodiscard]] std::string toString() const override - { - return "set tag '" + d_tag + "' to value '" + d_value + "'"; - } - -private: - std::string d_tag; - std::string d_value; -}; - -class ClearRecordTypesResponseAction : public DNSResponseAction, public boost::noncopyable -{ -public: - ClearRecordTypesResponseAction(std::unordered_set qtypes) : - d_qtypes(std::move(qtypes)) - { - } - - DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override - { - if (!d_qtypes.empty()) { - clearDNSPacketRecordTypes(response->getMutableData(), d_qtypes); - } - return DNSResponseAction::Action::None; - } - - [[nodiscard]] std::string toString() const override - { - return "clear record types"; - } - -private: - std::unordered_set d_qtypes{}; -}; - -class ContinueAction : public DNSAction -{ -public: - // this action does not stop the processing - ContinueAction(std::shared_ptr& action) : - d_action(action) - { - } - - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - if (d_action) { - /* call the action */ - auto action = (*d_action)(dnsquestion, ruleresult); - bool drop = false; - /* apply the changes if needed (pool selection, flags, etc */ - processRulesResult(action, *dnsquestion, *ruleresult, drop); - } - - /* but ignore the resulting action no matter what */ - return Action::None; - } - - [[nodiscard]] std::string toString() const override - { - if (d_action) { - return "continue after: " + (d_action ? d_action->toString() : ""); - } - return "no op"; - } - -private: - std::shared_ptr d_action; -}; - -#if defined(HAVE_DNS_OVER_HTTPS) || defined(HAVE_DNS_OVER_HTTP3) -class HTTPStatusAction : public DNSAction -{ -public: - HTTPStatusAction(int code, PacketBuffer body, std::string contentType, dnsdist::ResponseConfig responseConfig) : - d_responseConfig(responseConfig), d_body(std::move(body)), d_contentType(std::move(contentType)), d_code(code) - { - } - - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { -#if defined(HAVE_DNS_OVER_HTTPS) - if (dnsquestion->ids.du) { - dnsquestion->ids.du->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType); - dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) { - header.qr = true; // for good measure - setResponseHeadersFromConfig(header, d_responseConfig); - return true; - }); - return Action::HeaderModify; - } -#endif /* HAVE_DNS_OVER_HTTPS */ -#if defined(HAVE_DNS_OVER_HTTP3) - if (dnsquestion->ids.doh3u) { - dnsquestion->ids.doh3u->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType); - dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) { - header.qr = true; // for good measure - setResponseHeadersFromConfig(header, d_responseConfig); - return true; - }); - return Action::HeaderModify; - } -#endif /* HAVE_DNS_OVER_HTTP3 */ - return Action::None; - } - - [[nodiscard]] std::string toString() const override - { - return "return an HTTP status of " + std::to_string(d_code); - } - -private: - dnsdist::ResponseConfig d_responseConfig; - PacketBuffer d_body; - std::string d_contentType; - int d_code; -}; -#endif /* HAVE_DNS_OVER_HTTPS || HAVE_DNS_OVER_HTTP3 */ - -#if defined(HAVE_LMDB) || defined(HAVE_CDB) -class KeyValueStoreLookupAction : public DNSAction -{ -public: - // this action does not stop the processing - KeyValueStoreLookupAction(std::shared_ptr& kvs, std::shared_ptr& lookupKey, std::string destinationTag) : - d_kvs(kvs), d_key(lookupKey), d_tag(std::move(destinationTag)) - { - } - - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - std::vector keys = d_key->getKeys(*dnsquestion); - std::string result; - for (const auto& key : keys) { - if (d_kvs->getValue(key, result)) { - break; - } - } - - dnsquestion->setTag(d_tag, std::move(result)); - - return Action::None; - } - - [[nodiscard]] std::string toString() const override - { - return "lookup key-value store based on '" + d_key->toString() + "' and set the result in tag '" + d_tag + "'"; - } - -private: - std::shared_ptr d_kvs; - std::shared_ptr d_key; - std::string d_tag; -}; - -class KeyValueStoreRangeLookupAction : public DNSAction -{ -public: - // this action does not stop the processing - KeyValueStoreRangeLookupAction(std::shared_ptr& kvs, std::shared_ptr& lookupKey, std::string destinationTag) : - d_kvs(kvs), d_key(lookupKey), d_tag(std::move(destinationTag)) - { - } - - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - std::vector keys = d_key->getKeys(*dnsquestion); - std::string result; - for (const auto& key : keys) { - if (d_kvs->getRangeValue(key, result)) { - break; - } - } - - dnsquestion->setTag(d_tag, std::move(result)); - - return Action::None; - } - - [[nodiscard]] std::string toString() const override - { - return "do a range-based lookup in key-value store based on '" + d_key->toString() + "' and set the result in tag '" + d_tag + "'"; - } - -private: - std::shared_ptr d_kvs; - std::shared_ptr d_key; - std::string d_tag; -}; -#endif /* defined(HAVE_LMDB) || defined(HAVE_CDB) */ - -class MaxReturnedTTLAction : public DNSAction -{ -public: - MaxReturnedTTLAction(uint32_t cap) : - d_cap(cap) - { - } - - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - dnsquestion->ids.ttlCap = d_cap; - return DNSAction::Action::None; - } - - [[nodiscard]] std::string toString() const override - { - return "cap the TTL of the returned response to " + std::to_string(d_cap); - } - -private: - uint32_t d_cap; -}; - -class MaxReturnedTTLResponseAction : public DNSResponseAction -{ -public: - MaxReturnedTTLResponseAction(uint32_t cap) : - d_cap(cap) - { - } - - DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override - { - response->ids.ttlCap = d_cap; - return DNSResponseAction::Action::None; - } - - [[nodiscard]] std::string toString() const override - { - return "cap the TTL of the returned response to " + std::to_string(d_cap); - } - -private: - uint32_t d_cap; -}; - -class NegativeAndSOAAction : public DNSAction -{ -public: - struct SOAParams - { - uint32_t serial; - uint32_t refresh; - uint32_t retry; - uint32_t expire; - uint32_t minimum; - }; - - NegativeAndSOAAction(bool nxd, DNSName zone, uint32_t ttl, DNSName mname, DNSName rname, SOAParams params, bool soaInAuthoritySection, dnsdist::ResponseConfig responseConfig) : - d_responseConfig(responseConfig), d_zone(std::move(zone)), d_mname(std::move(mname)), d_rname(std::move(rname)), d_ttl(ttl), d_params(params), d_nxd(nxd), d_soaInAuthoritySection(soaInAuthoritySection) - { - } - - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - if (!setNegativeAndAdditionalSOA(*dnsquestion, d_nxd, d_zone, d_ttl, d_mname, d_rname, d_params.serial, d_params.refresh, d_params.retry, d_params.expire, d_params.minimum, d_soaInAuthoritySection)) { - return Action::None; - } - - dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) { - setResponseHeadersFromConfig(header, d_responseConfig); - return true; - }); - - return Action::Allow; - } - - [[nodiscard]] std::string toString() const override - { - return std::string(d_nxd ? "NXD" : "NODATA") + " with SOA"; - } - -private: - dnsdist::ResponseConfig d_responseConfig; - - DNSName d_zone; - DNSName d_mname; - DNSName d_rname; - uint32_t d_ttl; - SOAParams d_params; - bool d_nxd; - bool d_soaInAuthoritySection; -}; - -class SetProxyProtocolValuesAction : public DNSAction -{ -public: - // this action does not stop the processing - SetProxyProtocolValuesAction(const std::vector>& values) - { - d_values.reserve(values.size()); - for (const auto& value : values) { - d_values.push_back({value.second, value.first}); - } - } - - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - if (!dnsquestion->proxyProtocolValues) { - dnsquestion->proxyProtocolValues = make_unique>(); - } - - *(dnsquestion->proxyProtocolValues) = d_values; - - return Action::None; - } - - [[nodiscard]] std::string toString() const override - { - return "set Proxy-Protocol values"; - } - -private: - std::vector d_values; -}; - -class SetAdditionalProxyProtocolValueAction : public DNSAction -{ -public: - // this action does not stop the processing - SetAdditionalProxyProtocolValueAction(uint8_t type, std::string value) : - d_value(std::move(value)), d_type(type) - { - } - - DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override - { - if (!dnsquestion->proxyProtocolValues) { - dnsquestion->proxyProtocolValues = make_unique>(); - } - - dnsquestion->proxyProtocolValues->push_back({d_value, d_type}); - - return Action::None; - } - - [[nodiscard]] std::string toString() const override - { - return "add a Proxy-Protocol value of type " + std::to_string(d_type); - } - -private: - std::string d_value; - uint8_t d_type; -}; - -class SetReducedTTLResponseAction : public DNSResponseAction, public boost::noncopyable -{ -public: - // this action does not stop the processing - SetReducedTTLResponseAction(uint8_t percentage) : - d_ratio(percentage / 100.0) - { - } - - DNSResponseAction::Action operator()(DNSResponse* response, std::string* ruleresult) const override - { - // NOLINTNEXTLINE(bugprone-easily-swappable-parameters) - auto visitor = [&](uint8_t section, uint16_t qclass, uint16_t qtype, uint32_t ttl) { - return ttl * d_ratio; - }; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - editDNSPacketTTL(reinterpret_cast(response->getMutableData().data()), response->getData().size(), visitor); - return DNSResponseAction::Action::None; - } - - [[nodiscard]] std::string toString() const override - { - return "reduce ttl to " + std::to_string(d_ratio * 100) + " percent of its value"; - } - -private: - double d_ratio{1.0}; -}; - -class SetExtendedDNSErrorAction : public DNSAction -{ -public: - // this action does not stop the processing - SetExtendedDNSErrorAction(uint16_t infoCode, const std::string& extraText) - { - d_ede.infoCode = infoCode; - d_ede.extraText = extraText; - } - - DNSAction::Action operator()(DNSQuestion* dnsQuestion, std::string* ruleresult) const override - { - dnsQuestion->ids.d_extendedError = std::make_unique(d_ede); - - return DNSAction::Action::None; - } - - [[nodiscard]] std::string toString() const override - { - return "set EDNS Extended DNS Error to " + std::to_string(d_ede.infoCode) + (d_ede.extraText.empty() ? std::string() : std::string(": \"") + d_ede.extraText + std::string("\"")); - } - -private: - EDNSExtendedError d_ede; -}; - -class SetExtendedDNSErrorResponseAction : public DNSResponseAction -{ -public: - // this action does not stop the processing - SetExtendedDNSErrorResponseAction(uint16_t infoCode, const std::string& extraText) - { - d_ede.infoCode = infoCode; - d_ede.extraText = extraText; - } - - DNSResponseAction::Action operator()(DNSResponse* dnsResponse, std::string* ruleresult) const override - { - dnsResponse->ids.d_extendedError = std::make_unique(d_ede); - - return DNSResponseAction::Action::None; - } - - [[nodiscard]] std::string toString() const override - { - return "set EDNS Extended DNS Error to " + std::to_string(d_ede.infoCode) + (d_ede.extraText.empty() ? std::string() : std::string(": \"") + d_ede.extraText + std::string("\"")); - } - -private: - EDNSExtendedError d_ede; -}; +#include "dnstap.hh" +#include "remote_logger.hh" template static void addAction(IdentifierT identifier, const luadnsrule_t& var, const std::shared_ptr& action, boost::optional& params) @@ -2362,6 +59,17 @@ static dnsdist::ResponseConfig parseResponseConfig(boost::optional +static std::vector convertLuaArrayToRegular(const LuaArray& luaArray) +{ + std::vector out; + out.reserve(luaArray.size()); + for (const auto& entry : luaArray) { + out.emplace_back(entry.second); + } + return out; +} + // NOLINTNEXTLINE(readability-function-cognitive-complexity): this function declares Lua bindings, even with a good refactoring it will likely blow up the threshold void setupLuaActions(LuaContext& luaCtx) { @@ -2429,45 +137,6 @@ void setupLuaActions(LuaContext& luaCtx) luaCtx.registerFunction("reload", &DNSAction::reload); luaCtx.registerFunction("reload", &DNSResponseAction::reload); - luaCtx.writeFunction("LuaAction", [](LuaAction::func_t func) { - setLuaSideEffect(); - return std::shared_ptr(new LuaAction(std::move(func))); - }); - - luaCtx.writeFunction("LuaFFIAction", [](LuaFFIAction::func_t func) { - setLuaSideEffect(); - return std::shared_ptr(new LuaFFIAction(std::move(func))); - }); - - luaCtx.writeFunction("LuaFFIPerThreadAction", [](const std::string& code) { - setLuaSideEffect(); - return std::shared_ptr(new LuaFFIPerThreadAction(code)); - }); - - luaCtx.writeFunction("SetNoRecurseAction", []() { - return std::shared_ptr(new SetNoRecurseAction); - }); - - luaCtx.writeFunction("SetMacAddrAction", [](int code) { - return std::shared_ptr(new SetMacAddrAction(code)); - }); - - luaCtx.writeFunction("SetEDNSOptionAction", [](int code, const std::string& data) { - return std::shared_ptr(new SetEDNSOptionAction(code, data)); - }); - - luaCtx.writeFunction("PoolAction", [](const std::string& poolname, boost::optional stopProcessing) { - return std::shared_ptr(new PoolAction(poolname, stopProcessing ? *stopProcessing : true)); - }); - - luaCtx.writeFunction("QPSAction", [](int limit) { - return std::shared_ptr(new QPSAction(limit)); - }); - - luaCtx.writeFunction("QPSPoolAction", [](int limit, const std::string& poolname, boost::optional stopProcessing) { - return std::shared_ptr(new QPSPoolAction(limit, poolname, stopProcessing ? *stopProcessing : true)); - }); - luaCtx.writeFunction("SpoofAction", [](LuaTypeOrArrayOf inp, boost::optional vars) { vector addrs; if (auto* ipaddr = boost::get(&inp)) { @@ -2482,21 +151,22 @@ void setupLuaActions(LuaContext& luaCtx) auto responseConfig = parseResponseConfig(vars); checkAllParametersConsumed("SpoofAction", vars); - auto ret = std::shared_ptr(new SpoofAction(addrs, responseConfig)); + auto ret = dnsdist::actions::getSpoofAction(addrs, responseConfig); return ret; }); luaCtx.writeFunction("SpoofSVCAction", [](const LuaArray& parameters, boost::optional vars) { auto responseConfig = parseResponseConfig(vars); checkAllParametersConsumed("SpoofAction", vars); - auto ret = std::shared_ptr(new SpoofSVCAction(parameters, responseConfig)); + auto svcParams = convertLuaArrayToRegular(parameters); + auto ret = dnsdist::actions::getSpoofSVCAction(svcParams, responseConfig); return ret; }); luaCtx.writeFunction("SpoofCNAMEAction", [](const std::string& cname, boost::optional vars) { auto responseConfig = parseResponseConfig(vars); checkAllParametersConsumed("SpoofCNAMEAction", vars); - auto ret = std::shared_ptr(new SpoofAction(DNSName(cname), responseConfig)); + auto ret = dnsdist::actions::getSpoofAction(DNSName(cname), responseConfig); return ret; }); @@ -2522,7 +192,7 @@ void setupLuaActions(LuaContext& luaCtx) } auto responseConfig = parseResponseConfig(vars); checkAllParametersConsumed("SpoofRawAction", vars); - auto ret = std::shared_ptr(new SpoofAction(raws, qtypeForAnyParam, responseConfig)); + auto ret = dnsdist::actions::getSpoofAction(raws, qtypeForAnyParam, responseConfig); return ret; }); @@ -2531,46 +201,10 @@ void setupLuaActions(LuaContext& luaCtx) throw std::runtime_error(std::string("SpoofPacketAction: given packet len is too small")); } // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic) - auto ret = std::shared_ptr(new SpoofAction(PacketBuffer(response.data(), response.data() + len))); + auto ret = dnsdist::actions::getSpoofAction(PacketBuffer(response.data(), response.data() + len)); return ret; }); - luaCtx.writeFunction("DropAction", []() { - return std::shared_ptr(new DropAction); - }); - - luaCtx.writeFunction("AllowAction", []() { - return std::shared_ptr(new AllowAction); - }); - - luaCtx.writeFunction("NoneAction", []() { - return std::shared_ptr(new NoneAction); - }); - - luaCtx.writeFunction("DelayAction", [](int msec) { - return std::shared_ptr(new DelayAction(msec)); - }); - - luaCtx.writeFunction("TCAction", []() { - return std::shared_ptr(new TCAction); - }); - - luaCtx.writeFunction("TCResponseAction", []() { - return std::shared_ptr(new TCResponseAction); - }); - - luaCtx.writeFunction("SetDisableValidationAction", []() { - return std::shared_ptr(new SetDisableValidationAction); - }); - - luaCtx.writeFunction("LogAction", [](boost::optional fname, boost::optional binary, boost::optional append, boost::optional buffered, boost::optional verboseOnly, boost::optional includeTimestamp) { - return std::shared_ptr(new LogAction(fname ? *fname : "", binary ? *binary : true, append ? *append : false, buffered ? *buffered : false, verboseOnly ? *verboseOnly : true, includeTimestamp ? *includeTimestamp : false)); - }); - - luaCtx.writeFunction("LogResponseAction", [](boost::optional fname, boost::optional append, boost::optional buffered, boost::optional verboseOnly, boost::optional includeTimestamp) { - return std::shared_ptr(new LogResponseAction(fname ? *fname : "", append ? *append : false, buffered ? *buffered : false, verboseOnly ? *verboseOnly : true, includeTimestamp ? *includeTimestamp : false)); - }); - luaCtx.writeFunction("LimitTTLResponseAction", [](uint32_t min, uint32_t max, boost::optional> types) { std::unordered_set capTypes; if (types) { @@ -2579,30 +213,30 @@ void setupLuaActions(LuaContext& luaCtx) capTypes.insert(QType(type)); } } - return std::shared_ptr(new LimitTTLResponseAction(min, max, capTypes)); + return dnsdist::actions::getLimitTTLResponseAction(min, max, capTypes); }); luaCtx.writeFunction("SetMinTTLResponseAction", [](uint32_t min) { - return std::shared_ptr(new LimitTTLResponseAction(min)); + return dnsdist::actions::getLimitTTLResponseAction(min); }); luaCtx.writeFunction("SetMaxTTLResponseAction", [](uint32_t max) { - return std::shared_ptr(new LimitTTLResponseAction(0, max)); + return dnsdist::actions::getLimitTTLResponseAction(0, max); }); luaCtx.writeFunction("SetMaxReturnedTTLAction", [](uint32_t max) { - return std::shared_ptr(new MaxReturnedTTLAction(max)); + return dnsdist::actions::getSetMaxReturnedTTLAction(max); }); luaCtx.writeFunction("SetMaxReturnedTTLResponseAction", [](uint32_t max) { - return std::shared_ptr(new MaxReturnedTTLResponseAction(max)); + return dnsdist::actions::getSetMaxReturnedTTLResponseAction(max); }); luaCtx.writeFunction("SetReducedTTLResponseAction", [](uint8_t percentage) { if (percentage > 100) { throw std::runtime_error(std::string("SetReducedTTLResponseAction takes a percentage between 0 and 100.")); } - return std::shared_ptr(new SetReducedTTLResponseAction(percentage)); + return dnsdist::actions::getSetReducedTTLResponseAction(percentage); }); luaCtx.writeFunction("ClearRecordTypesResponseAction", [](LuaTypeOrArrayOf types) { @@ -2616,64 +250,25 @@ void setupLuaActions(LuaContext& luaCtx) qtypes.insert(tpair.second); } } - return std::shared_ptr(new ClearRecordTypesResponseAction(std::move(qtypes))); + return dnsdist::actions::getClearRecordTypesResponseAction(std::move(qtypes)); }); luaCtx.writeFunction("RCodeAction", [](uint8_t rcode, boost::optional vars) { auto responseConfig = parseResponseConfig(vars); checkAllParametersConsumed("RCodeAction", vars); - auto ret = std::shared_ptr(new RCodeAction(rcode, responseConfig)); + auto ret = dnsdist::actions::getRCodeAction(rcode, responseConfig); return ret; }); luaCtx.writeFunction("ERCodeAction", [](uint8_t rcode, boost::optional vars) { auto responseConfig = parseResponseConfig(vars); checkAllParametersConsumed("ERCodeAction", vars); - auto ret = std::shared_ptr(new ERCodeAction(rcode, responseConfig)); + auto ret = dnsdist::actions::getERCodeAction(rcode, responseConfig); return ret; }); - luaCtx.writeFunction("SetSkipCacheAction", []() { - return std::shared_ptr(new SetSkipCacheAction); - }); - - luaCtx.writeFunction("SetSkipCacheResponseAction", []() { - return std::shared_ptr(new SetSkipCacheResponseAction); - }); - - luaCtx.writeFunction("SetTempFailureCacheTTLAction", [](int maxTTL) { - return std::shared_ptr(new SetTempFailureCacheTTLAction(maxTTL)); - }); - - luaCtx.writeFunction("DropResponseAction", []() { - return std::shared_ptr(new DropResponseAction); - }); - - luaCtx.writeFunction("AllowResponseAction", []() { - return std::shared_ptr(new AllowResponseAction); - }); - - luaCtx.writeFunction("DelayResponseAction", [](int msec) { - return std::shared_ptr(new DelayResponseAction(msec)); - }); - - luaCtx.writeFunction("LuaResponseAction", [](LuaResponseAction::func_t func) { - setLuaSideEffect(); - return std::shared_ptr(new LuaResponseAction(std::move(func))); - }); - - luaCtx.writeFunction("LuaFFIResponseAction", [](LuaFFIResponseAction::func_t func) { - setLuaSideEffect(); - return std::shared_ptr(new LuaFFIResponseAction(std::move(func))); - }); - - luaCtx.writeFunction("LuaFFIPerThreadResponseAction", [](const std::string& code) { - setLuaSideEffect(); - return std::shared_ptr(new LuaFFIPerThreadResponseAction(code)); - }); - #ifndef DISABLE_PROTOBUF - luaCtx.writeFunction("RemoteLogAction", [](std::shared_ptr logger, boost::optional> alterFunc, boost::optional> vars, boost::optional> metas) { + luaCtx.writeFunction("RemoteLogAction", [](std::shared_ptr logger, boost::optional alterFunc, boost::optional> vars, boost::optional> metas) { if (logger) { // avoids potentially-evaluated-expression warning with clang. RemoteLoggerInterface& remoteLoggerRef = *logger; @@ -2684,9 +279,11 @@ void setupLuaActions(LuaContext& luaCtx) } std::string tags; - RemoteLogActionConfiguration config; + dnsdist::actions::RemoteLogActionConfiguration config; config.logger = std::move(logger); - config.alterQueryFunc = std::move(alterFunc); + if (alterFunc) { + config.alterQueryFunc = std::move(*alterFunc); + } getOptionalValue(vars, "serverID", config.serverID); getOptionalValue(vars, "ipEncryptKey", config.ipEncryptKey); getOptionalValue(vars, "exportTags", tags); @@ -2710,10 +307,10 @@ void setupLuaActions(LuaContext& luaCtx) checkAllParametersConsumed("RemoteLogAction", vars); - return std::shared_ptr(new RemoteLogAction(config)); + return dnsdist::actions::getRemoteLogAction(config); }); - luaCtx.writeFunction("RemoteLogResponseAction", [](std::shared_ptr logger, boost::optional> alterFunc, boost::optional includeCNAME, boost::optional> vars, boost::optional> metas) { + luaCtx.writeFunction("RemoteLogResponseAction", [](std::shared_ptr logger, boost::optional alterFunc, boost::optional includeCNAME, boost::optional> vars, boost::optional> metas) { if (logger) { // avoids potentially-evaluated-expression warning with clang. RemoteLoggerInterface& remoteLoggerRef = *logger; @@ -2724,9 +321,11 @@ void setupLuaActions(LuaContext& luaCtx) } std::string tags; - RemoteLogActionConfiguration config; + dnsdist::actions::RemoteLogActionConfiguration config; config.logger = std::move(logger); - config.alterResponseFunc = std::move(alterFunc); + if (alterFunc) { + config.alterResponseFunc = std::move(*alterFunc); + } config.includeCNAME = includeCNAME ? *includeCNAME : false; getOptionalValue(vars, "serverID", config.serverID); getOptionalValue(vars, "ipEncryptKey", config.ipEncryptKey); @@ -2752,115 +351,78 @@ void setupLuaActions(LuaContext& luaCtx) checkAllParametersConsumed("RemoteLogResponseAction", vars); - return std::shared_ptr(new RemoteLogResponseAction(config)); + return dnsdist::actions::getRemoteLogResponseAction(config); }); - luaCtx.writeFunction("DnstapLogAction", [](const std::string& identity, std::shared_ptr logger, boost::optional> alterFunc) { - return std::shared_ptr(new DnstapLogAction(identity, logger, std::move(alterFunc))); + luaCtx.writeFunction("DnstapLogAction", [](const std::string& identity, std::shared_ptr logger, boost::optional alterFunc) { + return dnsdist::actions::getDnstapLogAction(identity, std::move(logger), alterFunc ? std::move(*alterFunc) : std::optional()); }); - luaCtx.writeFunction("DnstapLogResponseAction", [](const std::string& identity, std::shared_ptr logger, boost::optional> alterFunc) { - return std::shared_ptr(new DnstapLogResponseAction(identity, logger, std::move(alterFunc))); + luaCtx.writeFunction("DnstapLogResponseAction", [](const std::string& identity, std::shared_ptr logger, boost::optional alterFunc) { + return dnsdist::actions::getDnstapLogResponseAction(identity, std::move(logger), alterFunc ? std::move(*alterFunc) : std::optional()); }); #endif /* DISABLE_PROTOBUF */ luaCtx.writeFunction("TeeAction", [](const std::string& remote, boost::optional addECS, boost::optional local, boost::optional addProxyProtocol) { - boost::optional localAddr{boost::none}; + std::optional localAddr; if (local) { localAddr = ComboAddress(*local, 0); } - return std::shared_ptr(new TeeAction(ComboAddress(remote, 53), localAddr, addECS ? *addECS : false, addProxyProtocol ? *addProxyProtocol : false)); - }); - - luaCtx.writeFunction("SetECSPrefixLengthAction", [](uint16_t v4PrefixLength, uint16_t v6PrefixLength) { - return std::shared_ptr(new SetECSPrefixLengthAction(v4PrefixLength, v6PrefixLength)); - }); - - luaCtx.writeFunction("SetECSOverrideAction", [](bool ecsOverride) { - return std::shared_ptr(new SetECSOverrideAction(ecsOverride)); - }); - - luaCtx.writeFunction("SetDisableECSAction", []() { - return std::shared_ptr(new SetDisableECSAction()); + return dnsdist::actions::getTeeAction(ComboAddress(remote, 53), localAddr, addECS ? *addECS : false, addProxyProtocol ? *addProxyProtocol : false); }); luaCtx.writeFunction("SetECSAction", [](const std::string& v4Netmask, boost::optional v6Netmask) { if (v6Netmask) { - return std::shared_ptr(new SetECSAction(Netmask(v4Netmask), Netmask(*v6Netmask))); + return dnsdist::actions::getSetECSAction(v4Netmask, *v6Netmask); } - return std::shared_ptr(new SetECSAction(Netmask(v4Netmask))); - }); - -#ifdef HAVE_NET_SNMP - luaCtx.writeFunction("SNMPTrapAction", [](boost::optional reason) { - return std::shared_ptr(new SNMPTrapAction(reason ? *reason : "")); - }); - - luaCtx.writeFunction("SNMPTrapResponseAction", [](boost::optional reason) { - return std::shared_ptr(new SNMPTrapResponseAction(reason ? *reason : "")); - }); -#endif /* HAVE_NET_SNMP */ - - luaCtx.writeFunction("SetTagAction", [](const std::string& tag, const std::string& value) { - return std::shared_ptr(new SetTagAction(tag, value)); - }); - - luaCtx.writeFunction("SetTagResponseAction", [](const std::string& tag, const std::string& value) { - return std::shared_ptr(new SetTagResponseAction(tag, value)); + return dnsdist::actions::getSetECSAction(v4Netmask); }); luaCtx.writeFunction("ContinueAction", [](std::shared_ptr action) { - return std::shared_ptr(new ContinueAction(action)); + return dnsdist::actions::getContinueAction(std::move(action)); }); #ifdef HAVE_DNS_OVER_HTTPS luaCtx.writeFunction("HTTPStatusAction", [](uint16_t status, std::string body, boost::optional contentType, boost::optional vars) { auto responseConfig = parseResponseConfig(vars); checkAllParametersConsumed("HTTPStatusAction", vars); - auto ret = std::shared_ptr(new HTTPStatusAction(status, PacketBuffer(body.begin(), body.end()), contentType ? *contentType : "", responseConfig)); + auto ret = dnsdist::actions::getHTTPStatusAction(status, PacketBuffer(body.begin(), body.end()), contentType ? *contentType : "", responseConfig); return ret; }); #endif /* HAVE_DNS_OVER_HTTPS */ #if defined(HAVE_LMDB) || defined(HAVE_CDB) luaCtx.writeFunction("KeyValueStoreLookupAction", [](std::shared_ptr& kvs, std::shared_ptr& lookupKey, const std::string& destinationTag) { - return std::shared_ptr(new KeyValueStoreLookupAction(kvs, lookupKey, destinationTag)); + return dnsdist::actions::getKeyValueStoreLookupAction(kvs, lookupKey, destinationTag); }); luaCtx.writeFunction("KeyValueStoreRangeLookupAction", [](std::shared_ptr& kvs, std::shared_ptr& lookupKey, const std::string& destinationTag) { - return std::shared_ptr(new KeyValueStoreRangeLookupAction(kvs, lookupKey, destinationTag)); + return dnsdist::actions::getKeyValueStoreRangeLookupAction(kvs, lookupKey, destinationTag); }); #endif /* defined(HAVE_LMDB) || defined(HAVE_CDB) */ luaCtx.writeFunction("NegativeAndSOAAction", [](bool nxd, const std::string& zone, uint32_t ttl, const std::string& mname, const std::string& rname, uint32_t serial, uint32_t refresh, uint32_t retry, uint32_t expire, uint32_t minimum, boost::optional vars) { - auto responseConfig = parseResponseConfig(vars); bool soaInAuthoritySection = false; getOptionalValue(vars, "soaInAuthoritySection", soaInAuthoritySection); - NegativeAndSOAAction::SOAParams params{ + auto responseConfig = parseResponseConfig(vars); + checkAllParametersConsumed("NegativeAndSOAAction", vars); + dnsdist::actions::SOAParams params{ .serial = serial, .refresh = refresh, .retry = retry, .expire = expire, .minimum = minimum}; - checkAllParametersConsumed("NegativeAndSOAAction", vars); - auto ret = std::shared_ptr(new NegativeAndSOAAction(nxd, DNSName(zone), ttl, DNSName(mname), DNSName(rname), params, soaInAuthoritySection, responseConfig)); + auto ret = dnsdist::actions::getNegativeAndSOAAction(nxd, DNSName(zone), ttl, DNSName(mname), DNSName(rname), params, soaInAuthoritySection, responseConfig); return ret; }); luaCtx.writeFunction("SetProxyProtocolValuesAction", [](const std::vector>& values) { - return std::shared_ptr(new SetProxyProtocolValuesAction(values)); - }); - - luaCtx.writeFunction("SetAdditionalProxyProtocolValueAction", [](uint8_t type, const std::string& value) { - return std::shared_ptr(new SetAdditionalProxyProtocolValueAction(type, value)); - }); - - luaCtx.writeFunction("SetExtendedDNSErrorAction", [](uint16_t infoCode, boost::optional extraText) { - return std::shared_ptr(new SetExtendedDNSErrorAction(infoCode, extraText ? *extraText : "")); + return dnsdist::actions::getSetProxyProtocolValuesAction(values); }); - luaCtx.writeFunction("SetExtendedDNSErrorResponseAction", [](uint16_t infoCode, boost::optional extraText) { - return std::shared_ptr(new SetExtendedDNSErrorResponseAction(infoCode, extraText ? *extraText : "")); - }); +// NOLINTNEXTLINE(bugprone-suspicious-include) +#include "dnsdist-lua-actions-generated.cc" +// NOLINTNEXTLINE(bugprone-suspicious-include) +#include "dnsdist-lua-response-actions-generated.cc" } diff --git a/pdns/dnsdistdist/dnsdist-lua-rules.cc b/pdns/dnsdistdist/dnsdist-lua-rules.cc index dd86ec473c..8709a2eb15 100644 --- a/pdns/dnsdistdist/dnsdist-lua-rules.cc +++ b/pdns/dnsdistdist/dnsdist-lua-rules.cc @@ -21,6 +21,7 @@ */ #include "dnsdist.hh" #include "dnsdist-lua.hh" +#include "dnsdist-rules.hh" #include "dnsdist-rules-factory.hh" #include "dnsdist-rule-chains.hh" #include "dns_random.hh" @@ -211,7 +212,7 @@ static void moveRuleToTop(IdentifierTypeT chainIdentifier) if (rules.empty()) { return; } - //coverity[auto_causes_copy] + // coverity[auto_causes_copy] auto subject = *rules.rbegin(); rules.erase(std::prev(rules.end())); rules.insert(rules.begin(), subject); @@ -228,7 +229,7 @@ static void mvRule(IdentifierTypeT chainIdentifier, unsigned int from, unsigned g_outputBuffer = "Error: attempt to move rules from/to invalid index\n"; return; } - //coverity[auto_causes_copy] + // coverity[auto_causes_copy] auto subject = rules[from]; rules.erase(rules.begin() + from); if (destination > rules.size()) { @@ -341,6 +342,12 @@ std::shared_ptr qnameSuffixRule(const boost::variant(&names); return std::shared_ptr(new SuffixMatchNodeRule(smn, quiet ? *quiet : false)); } + +template +std::optional boostToStandardOptional(const boost::optional& boostOpt) +{ + return boostOpt ? *boostOpt : std::optional(); +} } // NOLINTNEXTLINE(readability-function-cognitive-complexity): this function declares Lua bindings, even with a good refactoring it will likely blow up the threshold @@ -462,43 +469,6 @@ void setupLuaRules(LuaContext& luaCtx) }); } - luaCtx.writeFunction("MaxQPSIPRule", [](unsigned int qps, boost::optional ipv4trunc, boost::optional ipv6trunc, boost::optional burst, boost::optional expiration, boost::optional cleanupDelay, boost::optional scanFraction, boost::optional shards) { - return std::shared_ptr(new MaxQPSIPRule(qps, (burst ? *burst : qps), (ipv4trunc ? *ipv4trunc : 32), (ipv6trunc ? *ipv6trunc : 64), (expiration ? *expiration : 300), (cleanupDelay ? *cleanupDelay : 60), (scanFraction ? *scanFraction : 10), (shards ? *shards : 10))); - }); - - luaCtx.writeFunction("MaxQPSRule", [](unsigned int qps, boost::optional burst) { - if (!burst) { - return std::shared_ptr(new MaxQPSRule(qps)); - } - return std::shared_ptr(new MaxQPSRule(qps, *burst)); - }); - - luaCtx.writeFunction("RegexRule", [](const std::string& str) { - return std::shared_ptr(new RegexRule(str)); - }); - -#ifdef HAVE_DNS_OVER_HTTPS - luaCtx.writeFunction("HTTPHeaderRule", [](const std::string& header, const std::string& regex) { - return std::shared_ptr(new HTTPHeaderRule(header, regex)); - }); - luaCtx.writeFunction("HTTPPathRule", [](const std::string& path) { - return std::shared_ptr(new HTTPPathRule(path)); - }); - luaCtx.writeFunction("HTTPPathRegexRule", [](const std::string& regex) { - return std::shared_ptr(new HTTPPathRegexRule(regex)); - }); -#endif - -#ifdef HAVE_RE2 - luaCtx.writeFunction("RE2Rule", [](const std::string& str) { - return std::shared_ptr(new RE2Rule(str)); - }); -#endif - - luaCtx.writeFunction("SNIRule", [](const std::string& name) { - return std::shared_ptr(new SNIRule(name)); - }); - luaCtx.writeFunction("SuffixMatchNodeRule", qnameSuffixRule); luaCtx.writeFunction("NetmaskGroupRule", [](const boost::variant> netmasks, boost::optional src, boost::optional quiet) { @@ -562,18 +532,6 @@ void setupLuaRules(LuaContext& luaCtx) g_outputBuffer = (boost::format("Had %d matches out of %d, %.1f qps, in %.1f us\n") % matches % times % (1000000 * (1.0 * times / udiff)) % udiff).str(); }); - luaCtx.writeFunction("AllRule", []() { - return std::shared_ptr(new AllRule()); - }); - - luaCtx.writeFunction("ProbaRule", [](double proba) { - return std::shared_ptr(new ProbaRule(proba)); - }); - - luaCtx.writeFunction("QNameRule", [](const std::string& qname) { - return std::shared_ptr(new QNameRule(DNSName(qname))); - }); - luaCtx.writeFunction("QNameSuffixRule", qnameSuffixRule); luaCtx.writeFunction("QTypeRule", [](boost::variant str) { @@ -596,65 +554,28 @@ void setupLuaRules(LuaContext& luaCtx) return std::shared_ptr(new QClassRule(cla)); }); - luaCtx.writeFunction("OpcodeRule", [](uint64_t code) { - checkParameterBound("OpcodeRule", code, std::numeric_limits::max()); - return std::shared_ptr(new OpcodeRule(code)); - }); - - luaCtx.writeFunction("AndRule", [](const LuaArray>& rules) { + luaCtx.writeFunction("AndRule", [](const LuaArray>& rulePairs) { + std::vector> rules; + rules.reserve(rulePairs.size()); + for (const auto& pair : rulePairs) { + rules.emplace_back(pair.second); + } return std::shared_ptr(new AndRule(rules)); }); - luaCtx.writeFunction("OrRule", [](const LuaArray>& rules) { + luaCtx.writeFunction("OrRule", [](const LuaArray>& rulePairs) { + std::vector> rules; + rules.reserve(rulePairs.size()); + for (const auto& pair : rulePairs) { + rules.emplace_back(pair.second); + } return std::shared_ptr(new OrRule(rules)); }); - luaCtx.writeFunction("DSTPortRule", [](uint64_t port) { - checkParameterBound("DSTPortRule", port, std::numeric_limits::max()); - return std::shared_ptr(new DSTPortRule(port)); - }); - - luaCtx.writeFunction("TCPRule", [](bool tcp) { - return std::shared_ptr(new TCPRule(tcp)); - }); - - luaCtx.writeFunction("DNSSECRule", []() { - return std::shared_ptr(new DNSSECRule()); - }); - luaCtx.writeFunction("NotRule", [](const std::shared_ptr& rule) { return std::shared_ptr(new NotRule(rule)); }); - luaCtx.writeFunction("RecordsCountRule", [](uint64_t section, uint64_t minCount, uint64_t maxCount) { - checkParameterBound("RecordsCountRule", section, std::numeric_limits::max()); - checkParameterBound("RecordsCountRule", minCount, std::numeric_limits::max()); - checkParameterBound("RecordsCountRule", maxCount, std::numeric_limits::max()); - return std::shared_ptr(new RecordsCountRule(section, minCount, maxCount)); - }); - - luaCtx.writeFunction("RecordsTypeCountRule", [](uint64_t section, uint64_t type, uint64_t minCount, uint64_t maxCount) { - checkParameterBound("RecordsTypeCountRule", section, std::numeric_limits::max()); - checkParameterBound("RecordsTypeCountRule", type, std::numeric_limits::max()); - checkParameterBound("RecordsTypeCountRule", minCount, std::numeric_limits::max()); - checkParameterBound("RecordsTypeCountRule", maxCount, std::numeric_limits::max()); - return std::shared_ptr(new RecordsTypeCountRule(section, type, minCount, maxCount)); - }); - - luaCtx.writeFunction("TrailingDataRule", []() { - return std::shared_ptr(new TrailingDataRule()); - }); - - luaCtx.writeFunction("QNameLabelsCountRule", [](uint64_t minLabelsCount, uint64_t maxLabelsCount) { - checkParameterBound("QNameLabelsCountRule", minLabelsCount, std::numeric_limits::max()); - checkParameterBound("QNameLabelsCountRule", maxLabelsCount, std::numeric_limits::max()); - return std::shared_ptr(new QNameLabelsCountRule(minLabelsCount, maxLabelsCount)); - }); - - luaCtx.writeFunction("QNameWireLengthRule", [](uint64_t min, uint64_t max) { - return std::shared_ptr(new QNameWireLengthRule(min, max)); - }); - luaCtx.writeFunction("RCodeRule", [](uint64_t rcode) { checkParameterBound("RCodeRule", rcode, std::numeric_limits::max()); return std::shared_ptr(new RCodeRule(rcode)); @@ -665,36 +586,10 @@ void setupLuaRules(LuaContext& luaCtx) return std::shared_ptr(new ERCodeRule(rcode)); }); - luaCtx.writeFunction("EDNSVersionRule", [](uint64_t version) { - checkParameterBound("EDNSVersionRule", version, std::numeric_limits::max()); - return std::shared_ptr(new EDNSVersionRule(version)); - }); - - luaCtx.writeFunction("EDNSOptionRule", [](uint64_t optcode) { - checkParameterBound("EDNSOptionRule", optcode, std::numeric_limits::max()); - return std::shared_ptr(new EDNSOptionRule(optcode)); - }); - - luaCtx.writeFunction("RDRule", []() { - return std::shared_ptr(new RDRule()); - }); - - luaCtx.writeFunction("TagRule", [](const std::string& tag, boost::optional value) { - return std::shared_ptr(new TagRule(tag, std::move(value))); - }); - luaCtx.writeFunction("TimedIPSetRule", []() { return std::make_shared(); }); - luaCtx.writeFunction("PoolAvailableRule", [](const std::string& poolname) { - return std::shared_ptr(new PoolAvailableRule(poolname)); - }); - - luaCtx.writeFunction("PoolOutstandingRule", [](const std::string& poolname, uint64_t limit) { - return std::shared_ptr(new PoolOutstandingRule(poolname, limit)); - }); - luaCtx.registerFunction::*)()>("clear", [](const std::shared_ptr& tisr) { tisr->clear(); }); @@ -714,6 +609,10 @@ void setupLuaRules(LuaContext& luaCtx) tisr->toString(); }); + luaCtx.writeFunction("QNameRule", [](const std::string& qname) { + return std::shared_ptr(dnsdist::selectors::getQNameSelector(DNSName(qname))); + }); + luaCtx.writeFunction("QNameSetRule", [](const DNSNameSet& names) { return std::shared_ptr(new QNameSetRule(names)); }); @@ -728,23 +627,6 @@ void setupLuaRules(LuaContext& luaCtx) }); #endif /* defined(HAVE_LMDB) || defined(HAVE_CDB) */ - luaCtx.writeFunction("LuaRule", [](const LuaRule::func_t& func) { - return std::shared_ptr(new LuaRule(func)); - }); - - luaCtx.writeFunction("LuaFFIRule", [](const LuaFFIRule::func_t& func) { - return std::shared_ptr(new LuaFFIRule(func)); - }); - - luaCtx.writeFunction("LuaFFIPerThreadRule", [](const std::string& code) { - return std::shared_ptr(new LuaFFIPerThreadRule(code)); - }); - - luaCtx.writeFunction("ProxyProtocolValueRule", [](uint8_t type, boost::optional value) { - return std::shared_ptr(new ProxyProtocolValueRule(type, std::move(value))); - }); - - luaCtx.writeFunction("PayloadSizeRule", [](const std::string& comparison, uint16_t size) { - return std::shared_ptr(new PayloadSizeRule(comparison, size)); - }); +// NOLINTNEXTLINE(bugprone-suspicious-include) +#include "dnsdist-lua-selectors-generated.cc" } diff --git a/pdns/dnsdistdist/dnsdist-rules-factory.hh b/pdns/dnsdistdist/dnsdist-rules-factory.hh index 470cbda7ad..08711f6fe8 100644 --- a/pdns/dnsdistdist/dnsdist-rules-factory.hh +++ b/pdns/dnsdistdist/dnsdist-rules-factory.hh @@ -21,27 +21,33 @@ */ #pragma once -#include "dnsdist-rules.hh" +#include +#include +#include +#include #include "cachecleaner.hh" #include "dnsdist-ecs.hh" #include "dnsdist-kvs.hh" #include "dnsdist-lua.hh" #include "dnsdist-lua-ffi.hh" +#include "dnsdist-rules.hh" #include "dolog.hh" #include "dnsparser.hh" #include "dns_random.hh" +#include "uuid-utils.hh" -#include -#include -#include -#include +namespace dnsdist::selectors +{ +using LuaSelectorFunction = std::function; +using LuaSelectorFFIFunction = std::function; +} class MaxQPSIPRule : public DNSRule { public: - MaxQPSIPRule(unsigned int qps, unsigned int burst, unsigned int ipv4trunc=32, unsigned int ipv6trunc=64, unsigned int expiration=300, unsigned int cleanupDelay=60, unsigned int scanFraction=10, size_t shardsCount=10): - d_shards(shardsCount), d_qps(qps), d_burst(burst), d_ipv4trunc(ipv4trunc), d_ipv6trunc(ipv6trunc), d_cleanupDelay(cleanupDelay), d_expiration(expiration), d_scanFraction(scanFraction) + MaxQPSIPRule(unsigned int qps, unsigned int ipv4trunc = 32, unsigned int ipv6trunc = 64, unsigned int burst = 0, unsigned int expiration = 300, unsigned int cleanupDelay = 60, unsigned int scanFraction = 10, size_t shardsCount = 10) : + d_shards(shardsCount), d_qps(qps), d_burst(burst == 0 ? qps : burst), d_ipv4trunc(ipv4trunc), d_ipv6trunc(ipv6trunc), d_cleanupDelay(cleanupDelay), d_expiration(expiration), d_scanFraction(scanFraction) { d_cleaningUp.clear(); gettime(&d_lastCleanup, true); @@ -54,7 +60,7 @@ public: } } - size_t cleanup(const struct timespec& cutOff, size_t* scannedCount=nullptr) const + size_t cleanup(const struct timespec& cutOff, size_t* scannedCount = nullptr) const { size_t removed = 0; if (scannedCount != nullptr) { @@ -63,7 +69,7 @@ public: for (auto& shard : d_shards) { auto limits = shard.lock(); - const size_t toLook = std::round((1.0 * limits->size()) / d_scanFraction)+ 1; + const size_t toLook = std::round((1.0 * limits->size()) / d_scanFraction) + 1; size_t lookedAt = 0; auto& sequence = limits->get(); @@ -121,7 +127,7 @@ public: cleanupIfNeeded(dq->getQueryRealTime()); ComboAddress zeroport(dq->ids.origRemote); - zeroport.sin4.sin_port=0; + zeroport.sin4.sin_port = 0; zeroport.truncate(zeroport.sin4.sin_family == AF_INET ? d_ipv4trunc : d_ipv6trunc); auto hash = ComboAddress::addressOnlyHash()(zeroport); auto& shard = d_shards[hash % d_shards.size()]; @@ -140,7 +146,7 @@ public: string toString() const override { - return "IP (/"+std::to_string(d_ipv4trunc)+", /"+std::to_string(d_ipv6trunc)+") match for QPS over " + std::to_string(d_qps) + " burst "+ std::to_string(d_burst); + return "IP (/" + std::to_string(d_ipv4trunc) + ", /" + std::to_string(d_ipv6trunc) + ") match for QPS over " + std::to_string(d_qps) + " burst " + std::to_string(d_burst); } size_t getEntriesCount() const @@ -158,24 +164,27 @@ public: } private: - struct HashedTag {}; - struct SequencedTag {}; + struct HashedTag + { + }; + struct SequencedTag + { + }; struct Entry { - Entry(const ComboAddress& addr, BasicQPSLimiter&& limiter): d_limiter(limiter), d_addr(addr) + Entry(const ComboAddress& addr, BasicQPSLimiter&& limiter) : + d_limiter(limiter), d_addr(addr) { } mutable BasicQPSLimiter d_limiter; ComboAddress d_addr; }; - typedef multi_index_container< + using qpsContainer_t = multi_index_container< Entry, - indexed_by < - hashed_unique, member, ComboAddress::addressOnlyHash >, - sequenced > - > - > qpsContainer_t; + indexed_by< + hashed_unique, member, ComboAddress::addressOnlyHash>, + sequenced>>>; mutable std::vector> d_shards; mutable struct timespec d_lastCleanup; @@ -187,14 +196,15 @@ private: class MaxQPSRule : public DNSRule { public: - MaxQPSRule(unsigned int qps) - : d_qps(qps, qps) - {} - - MaxQPSRule(unsigned int qps, unsigned int burst) - : d_qps(qps, burst) - {} + MaxQPSRule(unsigned int qps) : + d_qps(qps, qps) + { + } + MaxQPSRule(unsigned int qps, unsigned int burst) : + d_qps(qps, burst > 0 ? burst : qps) + { + } bool matches(const DNSQuestion* qd) const override { @@ -206,31 +216,23 @@ public: return "Max " + std::to_string(d_qps.getRate()) + " qps"; } - private: mutable QPSLimiter d_qps; }; -class NMGRule : public DNSRule +class NetmaskGroupRule : public DNSRule { public: - NMGRule(const NetmaskGroup& nmg) : d_nmg(nmg) {} -protected: - NetmaskGroup d_nmg; -}; - -class NetmaskGroupRule : public NMGRule -{ -public: - NetmaskGroupRule(const NetmaskGroup& nmg, bool src, bool quiet = false) : NMGRule(nmg) + NetmaskGroupRule(const NetmaskGroup& nmg, bool src, bool quiet = false) : + d_nmg(nmg) { - d_src = src; - d_quiet = quiet; + d_src = src; + d_quiet = quiet; } bool matches(const DNSQuestion* dq) const override { - if(!d_src) { - return d_nmg.match(dq->ids.origDest); + if (!d_src) { + return d_nmg.match(dq->ids.origDest); } return d_nmg.match(dq->ids.origRemote); } @@ -238,15 +240,17 @@ public: string toString() const override { string ret = "Src: "; - if(!d_src) { - ret = "Dst: "; + if (!d_src) { + ret = "Dst: "; } if (d_quiet) { return ret + "in-group"; } return ret + d_nmg.toString(); } + private: + NetmaskGroup d_nmg; bool d_src; bool d_quiet; }; @@ -254,15 +258,16 @@ private: class TimedIPSetRule : public DNSRule, boost::noncopyable { private: - struct IPv6 { + struct IPv6 + { IPv6(const ComboAddress& ca) { - static_assert(sizeof(*this)==16, "IPv6 struct has wrong size"); + static_assert(sizeof(*this) == 16, "IPv6 struct has wrong size"); memcpy((char*)this, ca.sin6.sin6_addr.s6_addr, 16); } bool operator==(const IPv6& rhs) const { - return a==rhs.a && b==rhs.b; + return a == rhs.a && b == rhs.b; } uint64_t a, b; }; @@ -283,7 +288,8 @@ public: return false; } return time(nullptr) < fnd->second; - } else { + } + else { auto ip6s = d_ip6s.read_lock(); auto fnd = ip6s->find({dq->ids.origRemote}); if (fnd == ip6s->end()) { @@ -332,29 +338,27 @@ public: time_t now = time(nullptr); { auto ip4s = d_ip4s.write_lock(); - for (auto iter = ip4s->begin(); iter != ip4s->end(); ) { - if (iter->second < now) { - iter = ip4s->erase(iter); + for (auto iter = ip4s->begin(); iter != ip4s->end();) { + if (iter->second < now) { + iter = ip4s->erase(iter); } - else { - ++iter; + else { + ++iter; } } } { auto ip6s = d_ip6s.write_lock(); - for (auto iter = ip6s->begin(); iter != ip6s->end(); ) { - if (iter->second < now) { - iter = ip6s->erase(iter); + for (auto iter = ip6s->begin(); iter != ip6s->end();) { + if (iter->second < now) { + iter = ip6s->erase(iter); } - else { - ++iter; + else { + ++iter; } } - } - } string toString() const override @@ -374,23 +378,23 @@ public: } } - return "Src: "+std::to_string(count)+" ips"; + return "Src: " + std::to_string(count) + " ips"; } + private: struct IPv6Hash { std::size_t operator()(const IPv6& ip) const { - auto ah=std::hash{}(ip.a); - auto bh=std::hash{}(ip.b); - return ah & (bh<<1); + auto ah = std::hash{}(ip.a); + auto bh = std::hash{}(ip.b); + return ah & (bh << 1); } }; mutable SharedLockGuarded> d_ip6s; mutable SharedLockGuarded> d_ip4s; }; - class AllRule : public DNSRule { public: @@ -404,20 +408,17 @@ public: { return "All"; } - }; - class DNSSECRule : public DNSRule { public: DNSSECRule() { - } bool matches(const DNSQuestion* dq) const override { - return dq->getHeader()->cd || (dnsdist::getEDNSZ(*dq) & EDNS_HEADER_FLAG_DO); // turns out dig sets ad by default.. + return dq->getHeader()->cd || (dnsdist::getEDNSZ(*dq) & EDNS_HEADER_FLAG_DO); // turns out dig sets ad by default.. } string toString() const override @@ -429,11 +430,9 @@ public: class AndRule : public DNSRule { public: - AndRule(const std::vector > >& rules) + AndRule(const std::vector>& rules) : + d_rules(rules) { - for (const auto& r : rules) { - d_rules.push_back(r.second); - } } bool matches(const DNSQuestion* dq) const override @@ -451,30 +450,28 @@ public: string ret; for (const auto& rule : d_rules) { if (!ret.empty()) { - ret+= " && "; + ret += " && "; } - ret += "("+ rule->toString()+")"; + ret += "(" + rule->toString() + ")"; } return ret; } + private: - std::vector > d_rules; + std::vector> d_rules; }; - class OrRule : public DNSRule { public: - OrRule(const std::vector > >& rules) + OrRule(const std::vector>& rules) : + d_rules(rules) { - for (const auto& r : rules) { - d_rules.push_back(r.second); - } } bool matches(const DNSQuestion* dq) const override { - for (const auto& rule: d_rules) { + for (const auto& rule : d_rules) { if (rule->matches(dq)) { return true; } @@ -487,23 +484,23 @@ public: string ret; for (const auto& rule : d_rules) { if (!ret.empty()) { - ret+= " || "; + ret += " || "; } - ret += "("+ rule->toString()+")"; + ret += "(" + rule->toString() + ")"; } return ret; } + private: - std::vector > d_rules; + std::vector> d_rules; }; - class RegexRule : public DNSRule { public: - RegexRule(const std::string& regex) : d_regex(regex), d_visual(regex) + RegexRule(const std::string& regex) : + d_regex(regex), d_visual(regex) { - } bool matches(const DNSQuestion* dq) const override { @@ -512,21 +509,22 @@ public: string toString() const override { - return "Regex: "+d_visual; + return "Regex: " + d_visual; } + private: Regex d_regex; string d_visual; }; -#ifdef HAVE_RE2 +#if defined(HAVE_RE2) #include class RE2Rule : public DNSRule { public: - RE2Rule(const std::string& re2) : d_re2(re2, RE2::Latin1), d_visual(re2) + RE2Rule(const std::string& re2) : + d_re2(re2, RE2::Latin1), d_visual(re2) { - } bool matches(const DNSQuestion* dq) const override { @@ -535,21 +533,40 @@ public: string toString() const override { - return "RE2 match: "+d_visual; + return "RE2 match: " + d_visual; } + private: RE2 d_re2; string d_visual; }; -#endif +#else /* HAVE_RE2 */ +class RE2Rule : public DNSRule +{ +public: + RE2Rule(const std::string& re2) + { + throw std::runtime_error("RE2 support is disabled"); + } + bool matches(const DNSQuestion* dq) const override + { + return false; + } + + string toString() const override + { + return "Unsupported RE2"; + } +}; +#endif /* HAVE_RE2 */ -#ifdef HAVE_DNS_OVER_HTTPS class HTTPHeaderRule : public DNSRule { public: HTTPHeaderRule(const std::string& header, const std::string& regex); - bool matches(const DNSQuestion* dq) const override; + bool matches(const DNSQuestion* dnsQuestion) const override; string toString() const override; + private: string d_header; Regex d_regex; @@ -560,8 +577,9 @@ class HTTPPathRule : public DNSRule { public: HTTPPathRule(std::string path); - bool matches(const DNSQuestion* dq) const override; + bool matches(const DNSQuestion* dnsQuestion) const override; string toString() const override; + private: string d_path; }; @@ -570,18 +588,19 @@ class HTTPPathRegexRule : public DNSRule { public: HTTPPathRegexRule(const std::string& regex); - bool matches(const DNSQuestion* dq) const override; + bool matches(const DNSQuestion* dnsQuestion) const override; string toString() const override; + private: Regex d_regex; std::string d_visual; }; -#endif class SNIRule : public DNSRule { public: - SNIRule(const std::string& name) : d_sni(name) + SNIRule(const std::string& name) : + d_sni(name) { } bool matches(const DNSQuestion* dq) const override @@ -592,6 +611,7 @@ public: { return "SNI == " + d_sni; } + private: std::string d_sni; }; @@ -599,7 +619,8 @@ private: class SuffixMatchNodeRule : public DNSRule { public: - SuffixMatchNodeRule(const SuffixMatchNode& smn, bool quiet=false) : d_smn(smn), d_quiet(quiet) + SuffixMatchNodeRule(const SuffixMatchNode& smn, bool quiet = false) : + d_smn(smn), d_quiet(quiet) { } bool matches(const DNSQuestion* dq) const override @@ -608,11 +629,12 @@ public: } string toString() const override { - if(d_quiet) + if (d_quiet) return "qname==in-set"; else - return "qname in "+d_smn.toString(); + return "qname in " + d_smn.toString(); } + private: SuffixMatchNode d_smn; bool d_quiet; @@ -621,43 +643,51 @@ private: class QNameRule : public DNSRule { public: - QNameRule(const DNSName& qname) : d_qname(qname) + QNameRule(const DNSName& qname) : + d_qname(qname) { } bool matches(const DNSQuestion* dq) const override { - return d_qname==dq->ids.qname; + return d_qname == dq->ids.qname; } string toString() const override { - return "qname=="+d_qname.toString(); + return "qname==" + d_qname.toString(); } + private: DNSName d_qname; }; -class QNameSetRule : public DNSRule { +class QNameSetRule : public DNSRule +{ public: - QNameSetRule(const DNSNameSet& names) : qname_idx(names) {} + QNameSetRule(const DNSNameSet& names) : + qname_idx(names) {} - bool matches(const DNSQuestion* dq) const override { - return qname_idx.find(dq->ids.qname) != qname_idx.end(); - } + bool matches(const DNSQuestion* dq) const override + { + return qname_idx.find(dq->ids.qname) != qname_idx.end(); + } + + string toString() const override + { + std::stringstream ss; + ss << "qname in DNSNameSet(" << qname_idx.size() << " FQDNs)"; + return ss.str(); + } - string toString() const override { - std::stringstream ss; - ss << "qname in DNSNameSet(" << qname_idx.size() << " FQDNs)"; - return ss.str(); - } private: - DNSNameSet qname_idx; + DNSNameSet qname_idx; }; class QTypeRule : public DNSRule { public: - QTypeRule(uint16_t qtype) : d_qtype(qtype) + QTypeRule(uint16_t qtype) : + d_qtype(qtype) { } bool matches(const DNSQuestion* dq) const override @@ -667,8 +697,9 @@ public: string toString() const override { QType qt(d_qtype); - return "qtype=="+qt.toString(); + return "qtype==" + qt.toString(); } + private: uint16_t d_qtype; }; @@ -676,7 +707,8 @@ private: class QClassRule : public DNSRule { public: - QClassRule(uint16_t qclass) : d_qclass(qclass) + QClassRule(uint16_t qclass) : + d_qclass(qclass) { } bool matches(const DNSQuestion* dq) const override @@ -685,8 +717,9 @@ public: } string toString() const override { - return "qclass=="+std::to_string(d_qclass); + return "qclass==" + std::to_string(d_qclass); } + private: uint16_t d_qclass; }; @@ -694,7 +727,8 @@ private: class OpcodeRule : public DNSRule { public: - OpcodeRule(uint8_t opcode) : d_opcode(opcode) + OpcodeRule(uint8_t opcode) : + d_opcode(opcode) { } bool matches(const DNSQuestion* dq) const override @@ -703,8 +737,9 @@ public: } string toString() const override { - return "opcode=="+std::to_string(d_opcode); + return "opcode==" + std::to_string(d_opcode); } + private: uint8_t d_opcode; }; @@ -712,7 +747,8 @@ private: class DSTPortRule : public DNSRule { public: - DSTPortRule(uint16_t port) : d_port(port) + DSTPortRule(uint16_t port) : + d_port(port) { } bool matches(const DNSQuestion* dq) const override @@ -721,8 +757,9 @@ public: } string toString() const override { - return "dst port=="+std::to_string(d_port); + return "dst port==" + std::to_string(d_port); } + private: uint16_t d_port; }; @@ -730,7 +767,8 @@ private: class TCPRule : public DNSRule { public: - TCPRule(bool tcp): d_tcp(tcp) + TCPRule(bool tcp) : + d_tcp(tcp) { } bool matches(const DNSQuestion* dq) const override @@ -741,15 +779,16 @@ public: { return (d_tcp ? "TCP" : "UDP"); } + private: bool d_tcp; }; - class NotRule : public DNSRule { public: - NotRule(const std::shared_ptr& rule): d_rule(rule) + NotRule(const std::shared_ptr& rule) : + d_rule(rule) { } bool matches(const DNSQuestion* dq) const override @@ -758,8 +797,9 @@ public: } string toString() const override { - return "!("+ d_rule->toString()+")"; + return "!(" + d_rule->toString() + ")"; } + private: std::shared_ptr d_rule; }; @@ -767,13 +807,14 @@ private: class RecordsCountRule : public DNSRule { public: - RecordsCountRule(uint8_t section, uint16_t minCount, uint16_t maxCount): d_minCount(minCount), d_maxCount(maxCount), d_section(section) + RecordsCountRule(uint8_t section, uint16_t minCount, uint16_t maxCount) : + d_minCount(minCount), d_maxCount(maxCount), d_section(section) { } bool matches(const DNSQuestion* dq) const override { uint16_t count = 0; - switch(d_section) { + switch (d_section) { case 0: count = ntohs(dq->getHeader()->qdcount); break; @@ -792,7 +833,7 @@ public: string toString() const override { string section; - switch(d_section) { + switch (d_section) { case 0: section = "QD"; break; @@ -806,8 +847,9 @@ public: section = "AR"; break; } - return std::to_string(d_minCount) + " <= records in " + section + " <= "+ std::to_string(d_maxCount); + return std::to_string(d_minCount) + " <= records in " + section + " <= " + std::to_string(d_maxCount); } + private: uint16_t d_minCount; uint16_t d_maxCount; @@ -817,13 +859,14 @@ private: class RecordsTypeCountRule : public DNSRule { public: - RecordsTypeCountRule(uint8_t section, uint16_t type, uint16_t minCount, uint16_t maxCount): d_type(type), d_minCount(minCount), d_maxCount(maxCount), d_section(section) + RecordsTypeCountRule(uint8_t section, uint16_t type, uint16_t minCount, uint16_t maxCount) : + d_type(type), d_minCount(minCount), d_maxCount(maxCount), d_section(section) { } bool matches(const DNSQuestion* dq) const override { uint16_t count = 0; - switch(d_section) { + switch (d_section) { case 0: count = ntohs(dq->getHeader()->qdcount); break; @@ -846,7 +889,7 @@ public: string toString() const override { string section; - switch(d_section) { + switch (d_section) { case 0: section = "QD"; break; @@ -860,8 +903,9 @@ public: section = "AR"; break; } - return std::to_string(d_minCount) + " <= " + QType(d_type).toString() + " records in " + section + " <= "+ std::to_string(d_maxCount); + return std::to_string(d_minCount) + " <= " + QType(d_type).toString() + " records in " + section + " <= " + std::to_string(d_maxCount); } + private: uint16_t d_type; uint16_t d_minCount; @@ -889,7 +933,8 @@ public: class QNameLabelsCountRule : public DNSRule { public: - QNameLabelsCountRule(unsigned int minLabelsCount, unsigned int maxLabelsCount): d_min(minLabelsCount), d_max(maxLabelsCount) + QNameLabelsCountRule(unsigned int minLabelsCount, unsigned int maxLabelsCount) : + d_min(minLabelsCount), d_max(maxLabelsCount) { } bool matches(const DNSQuestion* dq) const override @@ -901,6 +946,7 @@ public: { return "labels count < " + std::to_string(d_min) + " || labels count > " + std::to_string(d_max); } + private: unsigned int d_min; unsigned int d_max; @@ -909,7 +955,8 @@ private: class QNameWireLengthRule : public DNSRule { public: - QNameWireLengthRule(size_t min, size_t max): d_min(min), d_max(max) + QNameWireLengthRule(size_t min, size_t max) : + d_min(min), d_max(max) { } bool matches(const DNSQuestion* dq) const override @@ -921,6 +968,7 @@ public: { return "wire length < " + std::to_string(d_min) + " || wire length > " + std::to_string(d_max); } + private: size_t d_min; size_t d_max; @@ -929,7 +977,8 @@ private: class RCodeRule : public DNSRule { public: - RCodeRule(uint8_t rcode) : d_rcode(rcode) + RCodeRule(uint8_t rcode) : + d_rcode(rcode) { } bool matches(const DNSQuestion* dq) const override @@ -938,8 +987,9 @@ public: } string toString() const override { - return "rcode=="+RCode::to_s(d_rcode); + return "rcode==" + RCode::to_s(d_rcode); } + private: uint8_t d_rcode; }; @@ -947,7 +997,8 @@ private: class ERCodeRule : public DNSRule { public: - ERCodeRule(uint8_t rcode) : d_rcode(rcode & 0xF), d_extrcode(rcode >> 4) + ERCodeRule(uint8_t rcode) : + d_rcode(rcode & 0xF), d_extrcode(rcode >> 4) { } bool matches(const DNSQuestion* dq) const override @@ -966,17 +1017,19 @@ public: } string toString() const override { - return "ercode=="+ERCode::to_s(d_rcode | (d_extrcode << 4)); + return "ercode==" + ERCode::to_s(d_rcode | (d_extrcode << 4)); } + private: - uint8_t d_rcode; // plain DNS Rcode - uint8_t d_extrcode; // upper bits in EDNS0 record + uint8_t d_rcode; // plain DNS Rcode + uint8_t d_extrcode; // upper bits in EDNS0 record }; class EDNSVersionRule : public DNSRule { public: - EDNSVersionRule(uint8_t version) : d_version(version) + EDNSVersionRule(uint8_t version) : + d_version(version) { } bool matches(const DNSQuestion* dq) const override @@ -990,8 +1043,9 @@ public: } string toString() const override { - return "ednsversion>"+std::to_string(d_version); + return "ednsversion>" + std::to_string(d_version); } + private: uint8_t d_version; }; @@ -999,7 +1053,8 @@ private: class EDNSOptionRule : public DNSRule { public: - EDNSOptionRule(uint16_t optcode) : d_optcode(optcode) + EDNSOptionRule(uint16_t optcode) : + d_optcode(optcode) { } bool matches(const DNSQuestion* dq) const override @@ -1026,8 +1081,9 @@ public: } string toString() const override { - return "ednsoptcode=="+std::to_string(d_optcode); + return "ednsoptcode==" + std::to_string(d_optcode); } + private: uint16_t d_optcode; }; @@ -1051,20 +1107,22 @@ public: class ProbaRule : public DNSRule { public: - ProbaRule(double proba) : d_proba(proba) + ProbaRule(double proba) : + d_proba(proba) { } bool matches(const DNSQuestion* dq) const override { - if(d_proba == 1.0) + if (d_proba == 1.0) return true; - double rnd = 1.0*dns_random_uint32() / UINT32_MAX; + double rnd = 1.0 * dns_random_uint32() / UINT32_MAX; return rnd > (1.0 - d_proba); } string toString() const override { return "match with prob. " + (boost::format("%0.2f") % d_proba).str(); } + private: double d_proba; }; @@ -1072,7 +1130,8 @@ private: class TagRule : public DNSRule { public: - TagRule(const std::string& tag, boost::optional value) : d_value(std::move(value)), d_tag(tag) + TagRule(const std::string& tag, boost::optional value) : + d_value(std::move(value)), d_tag(tag) { } bool matches(const DNSQuestion* dq) const override @@ -1086,7 +1145,7 @@ public: return false; } - if (!d_value) { + if (!d_value || d_value->empty()) { return true; } @@ -1106,7 +1165,8 @@ private: class PoolAvailableRule : public DNSRule { public: - PoolAvailableRule(const std::string& poolname) : d_poolname(poolname) + PoolAvailableRule(const std::string& poolname) : + d_poolname(poolname) { } @@ -1119,6 +1179,7 @@ public: { return "pool '" + d_poolname + "' is available"; } + private: std::string d_poolname; }; @@ -1126,7 +1187,8 @@ private: class PoolOutstandingRule : public DNSRule { public: - PoolOutstandingRule(const std::string& poolname, const size_t limit) : d_poolname(poolname), d_limit(limit) + PoolOutstandingRule(const std::string& poolname, const size_t limit) : + d_poolname(poolname), d_limit(limit) { } @@ -1139,15 +1201,17 @@ public: { return "pool '" + d_poolname + "' outstanding > " + std::to_string(d_limit); } + private: std::string d_poolname; size_t d_limit; }; -class KeyValueStoreLookupRule: public DNSRule +class KeyValueStoreLookupRule : public DNSRule { public: - KeyValueStoreLookupRule(std::shared_ptr& kvs, std::shared_ptr& lookupKey): d_kvs(kvs), d_key(lookupKey) + KeyValueStoreLookupRule(const std::shared_ptr& kvs, const std::shared_ptr& lookupKey) : + d_kvs(kvs), d_key(lookupKey) { } @@ -1173,10 +1237,11 @@ private: std::shared_ptr d_key; }; -class KeyValueStoreRangeLookupRule: public DNSRule +class KeyValueStoreRangeLookupRule : public DNSRule { public: - KeyValueStoreRangeLookupRule(std::shared_ptr& kvs, std::shared_ptr& lookupKey): d_kvs(kvs), d_key(lookupKey) + KeyValueStoreRangeLookupRule(const std::shared_ptr& kvs, const std::shared_ptr& lookupKey) : + d_kvs(kvs), d_key(lookupKey) { } @@ -1206,8 +1271,8 @@ private: class LuaRule : public DNSRule { public: - typedef std::function func_t; - LuaRule(const func_t& func): d_func(func) + LuaRule(const dnsdist::selectors::LuaSelectorFunction& func) : + d_func(func) {} bool matches(const DNSQuestion* dq) const override @@ -1215,9 +1280,11 @@ public: try { auto lock = g_lua.lock(); return d_func(dq); - } catch (const std::exception &e) { + } + catch (const std::exception& e) { warnlog("LuaRule failed inside Lua: %s", e.what()); - } catch (...) { + } + catch (...) { warnlog("LuaRule failed inside Lua: [unknown exception]"); } return false; @@ -1227,15 +1294,16 @@ public: { return "Lua script"; } + private: - func_t d_func; + dnsdist::selectors::LuaSelectorFunction d_func; }; class LuaFFIRule : public DNSRule { public: - typedef std::function func_t; - LuaFFIRule(const func_t& func): d_func(func) + LuaFFIRule(const dnsdist::selectors::LuaSelectorFFIFunction& func) : + d_func(func) {} bool matches(const DNSQuestion* dq) const override @@ -1244,9 +1312,11 @@ public: try { auto lock = g_lua.lock(); return d_func(&dqffi); - } catch (const std::exception &e) { + } + catch (const std::exception& e) { warnlog("LuaFFIRule failed inside Lua: %s", e.what()); - } catch (...) { + } + catch (...) { warnlog("LuaFFIRule failed inside Lua: [unknown exception]"); } return false; @@ -1256,16 +1326,16 @@ public: { return "Lua FFI script"; } + private: - func_t d_func; + dnsdist::selectors::LuaSelectorFFIFunction d_func; }; class LuaFFIPerThreadRule : public DNSRule { public: - typedef std::function func_t; - - LuaFFIPerThreadRule(const std::string& code): d_functionCode(code), d_functionID(s_functionsCounter++) + LuaFFIPerThreadRule(const std::string& code) : + d_functionCode(code), d_functionID(s_functionsCounter++) { } @@ -1278,7 +1348,7 @@ public: /* mark the state as initialized first so if there is a syntax error we only try to execute the code once */ state.d_initialized = true; - state.d_func = state.d_luaContext.executeCode(d_functionCode); + state.d_func = state.d_luaContext.executeCode(d_functionCode); } if (!state.d_func) { @@ -1289,7 +1359,7 @@ public: dnsdist_ffi_dnsquestion_t dqffi(const_cast(dq)); return state.d_func(&dqffi); } - catch (const std::exception &e) { + catch (const std::exception& e) { warnlog("LuaFFIPerthreadRule failed inside Lua: %s", e.what()); } catch (...) { @@ -1302,11 +1372,12 @@ public: { return "Lua FFI per-thread script"; } + private: struct PerThreadState { LuaContext d_luaContext; - func_t d_func; + dnsdist::selectors::LuaSelectorFFIFunction d_func; bool d_initialized{false}; }; @@ -1319,7 +1390,8 @@ private: class ProxyProtocolValueRule : public DNSRule { public: - ProxyProtocolValueRule(uint8_t type, boost::optional value): d_value(std::move(value)), d_type(type) + ProxyProtocolValueRule(uint8_t type, boost::optional value) : + d_value(std::move(value)), d_type(type) { } @@ -1330,7 +1402,7 @@ public: } for (const auto& entry : *dq->proxyProtocolValues) { - if (entry.type == d_type && (!d_value || entry.content == *d_value)) { + if (entry.type == d_type && (!d_value || d_value->empty() || entry.content == *d_value)) { return true; } } @@ -1353,9 +1425,18 @@ private: class PayloadSizeRule : public DNSRule { - enum class Comparisons : uint8_t { equal, greater, greaterOrEqual, smaller, smallerOrEqual }; + enum class Comparisons : uint8_t + { + equal, + greater, + greaterOrEqual, + smaller, + smallerOrEqual + }; + public: - PayloadSizeRule(const std::string& comparison, uint16_t size): d_size(size) + PayloadSizeRule(const std::string& comparison, uint16_t size) : + d_size(size) { if (comparison == "equal") { d_comparison = Comparisons::equal; @@ -1400,12 +1481,11 @@ public: string toString() const override { static const std::array comparisonStr{ - "equal to" , + "equal to", "greater than", "equal to or greater than", "smaller than", - "equal to or smaller than" - }; + "equal to or smaller than"}; return "payload size is " + comparisonStr.at(static_cast(d_comparison)) + " " + std::to_string(d_size); } @@ -1413,3 +1493,20 @@ private: uint16_t d_size; Comparisons d_comparison; }; + +namespace dnsdist::selectors +{ +std::shared_ptr getAndSelector(const std::vector>& rules); +std::shared_ptr getOrSelector(const std::vector>& rules); +std::shared_ptr getNotSelector(const std::shared_ptr& rule); +std::shared_ptr getQNameSelector(const DNSName& qname); +std::shared_ptr getQNameSetSelector(const DNSNameSet& qnames); +std::shared_ptr getQNameSuffixSelector(const SuffixMatchNode& suffixes, bool quiet); +std::shared_ptr getQTypeSelector(const std::string& qtypeStr, uint16_t qtypeCode); +std::shared_ptr getQClassSelector(const std::string& qclassStr, uint16_t qclassCode); +std::shared_ptr getNetmaskGroupSelector(const NetmaskGroup& nmg, bool source, bool quiet); +std::shared_ptr getKeyValueStoreLookupSelector(const std::shared_ptr& kvs, const std::shared_ptr& lookupKey); +std::shared_ptr getKeyValueStoreRangeLookupSelector(const std::shared_ptr& kvs, const std::shared_ptr& lookupKey); + +#include "dnsdist-selectors-factory-generated.hh" +} diff --git a/pdns/dnsdistdist/dnsdist-rules.cc b/pdns/dnsdistdist/dnsdist-rules.cc index fe716a7d1d..e4904fdca1 100644 --- a/pdns/dnsdistdist/dnsdist-rules.cc +++ b/pdns/dnsdistdist/dnsdist-rules.cc @@ -123,3 +123,75 @@ string HTTPPathRegexRule::toString() const { return d_visual; } + +namespace dnsdist::selectors +{ +std::shared_ptr getQClassSelector(const std::string& qclassStr, uint16_t qclassCode) +{ + QClass qclass(qclassCode); + if (!qclassStr.empty()) { + qclass = QClass(std::string(qclassStr)); + } + + return std::make_shared(qclass); +} + +std::shared_ptr getQTypeSelector(const std::string& qtypeStr, uint16_t qtypeCode) +{ + QType qtype(qtypeCode); + if (!qtypeStr.empty()) { + qtype = std::string(qtypeStr); + } + + return std::make_shared(qtype); +} + +std::shared_ptr getQNameSuffixSelector(const SuffixMatchNode& suffixes, bool quiet) +{ + return std::make_shared(suffixes, quiet); +} + +std::shared_ptr getQNameSetSelector(const DNSNameSet& qnames) +{ + return std::make_shared(qnames); +} + +std::shared_ptr getQNameSelector(const DNSName& qname) +{ + return std::make_shared(qname); +} + +std::shared_ptr getNetmaskGroupSelector(const NetmaskGroup& nmg, bool source, bool quiet) +{ + return std::make_shared(nmg, source, quiet); +} + +std::shared_ptr getKeyValueStoreLookupSelector(const std::shared_ptr& kvs, const std::shared_ptr& lookupKey) +{ + return std::make_shared(kvs, lookupKey); +} + +std::shared_ptr getKeyValueStoreRangeLookupSelector(const std::shared_ptr& kvs, const std::shared_ptr& lookupKey) +{ + return std::make_shared(kvs, lookupKey); +} + +std::shared_ptr getAndSelector(const std::vector>& rules) +{ + return std::make_shared(rules); +} + +std::shared_ptr getOrSelector(const std::vector>& rules) +{ + return std::make_shared(rules); +} + +std::shared_ptr getNotSelector(const std::shared_ptr& rule) +{ + return std::make_shared(rule); +} + +// NOLINTNEXTLINE(bugprone-suspicious-include) +#include "dnsdist-selectors-factory-generated.cc" + +}