From: Aki Tuomi Date: Mon, 22 Feb 2021 12:07:28 +0000 (+0200) Subject: dnsdist: Warn on unsupported parameters X-Git-Tag: dnsdist-1.8.0-rc1~43^2~6 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=45f86d734a8f0a9c81a5c22c7a717ea9f009ad6f;p=thirdparty%2Fpdns.git dnsdist: Warn on unsupported parameters --- diff --git a/pdns/dnsdist-lua-actions.cc b/pdns/dnsdist-lua-actions.cc index fb0695c376..245032fdfb 100644 --- a/pdns/dnsdist-lua-actions.cc +++ b/pdns/dnsdist-lua-actions.cc @@ -31,7 +31,6 @@ #include "dnsdist-kvs.hh" #include "dnsdist-svc.hh" -#include "dolog.hh" #include "dnstap.hh" #include "dnswriter.hh" #include "ednsoptions.hh" @@ -2100,6 +2099,7 @@ static void addAction(GlobalStateHolder > *someRuleActions, const luad boost::uuids::uuid uuid; uint64_t creationOrder; parseRuleParams(params, uuid, name, creationOrder); + checkAllParametersConsumed("addAction", params); auto rule = makeRule(var); someRuleActions->modify([&rule, &action, &uuid, creationOrder, &name](vector& ruleactions){ @@ -2109,22 +2109,12 @@ static void addAction(GlobalStateHolder > *someRuleActions, const luad typedef std::unordered_map > responseParams_t; -static void parseResponseConfig(boost::optional vars, ResponseConfig& config) +static void parseResponseConfig(boost::optional& vars, ResponseConfig& config) { - if (vars) { - if (vars->count("ttl")) { - config.ttl = boost::get((*vars)["ttl"]); - } - if (vars->count("aa")) { - config.setAA = boost::get((*vars)["aa"]); - } - if (vars->count("ad")) { - config.setAD = boost::get((*vars)["ad"]); - } - if (vars->count("ra")) { - config.setRA = boost::get((*vars)["ra"]); - } - } + getOptionalValue(vars, "ttl", config.ttl); + getOptionalValue(vars, "aa", config.setAA); + getOptionalValue(vars, "ad", config.setAD); + getOptionalValue(vars, "ra", config.setRA); } void setResponseHeadersFromConfig(dnsheader& dh, const ResponseConfig& config) @@ -2153,6 +2143,7 @@ void setupLuaActions(LuaContext& luaCtx) uint64_t creationOrder; std::string name; parseRuleParams(params, uuid, name, creationOrder); + checkAllParametersConsumed("newRuleAction", params); auto rule = makeRule(dnsrule); DNSDistRuleAction ra({std::move(rule), action, std::move(name), uuid, creationOrder}); @@ -2277,6 +2268,7 @@ void setupLuaActions(LuaContext& luaCtx) auto ret = std::shared_ptr(new SpoofAction(addrs)); auto sa = std::dynamic_pointer_cast(ret); parseResponseConfig(vars, sa->d_responseConfig); + checkAllParametersConsumed("SpoofAction", vars); return ret; }); @@ -2291,6 +2283,7 @@ void setupLuaActions(LuaContext& luaCtx) auto ret = std::shared_ptr(new SpoofAction(DNSName(a))); auto sa = std::dynamic_pointer_cast(ret); parseResponseConfig(vars, sa->d_responseConfig); + checkAllParametersConsumed("SpoofCNAMEAction", vars); return ret; }); @@ -2308,6 +2301,7 @@ void setupLuaActions(LuaContext& luaCtx) auto ret = std::shared_ptr(new SpoofAction(raws)); auto sa = std::dynamic_pointer_cast(ret); parseResponseConfig(vars, sa->d_responseConfig); + checkAllParametersConsumed("SpoofRawAction", vars); return ret; }); @@ -2402,6 +2396,7 @@ void setupLuaActions(LuaContext& luaCtx) auto ret = std::shared_ptr(new RCodeAction(rcode)); auto rca = std::dynamic_pointer_cast(ret); parseResponseConfig(vars, rca->d_responseConfig); + checkAllParametersConsumed("RCodeAction", vars); return ret; }); @@ -2409,6 +2404,7 @@ void setupLuaActions(LuaContext& luaCtx) auto ret = std::shared_ptr(new ERCodeAction(rcode)); auto erca = std::dynamic_pointer_cast(ret); parseResponseConfig(vars, erca->d_responseConfig); + checkAllParametersConsumed("ERCodeAction", vars); return ret; }); @@ -2464,14 +2460,9 @@ void setupLuaActions(LuaContext& luaCtx) std::string serverID; std::string ipEncryptKey; - if (vars) { - if (vars->count("serverID")) { - serverID = boost::get((*vars)["serverID"]); - } - if (vars->count("ipEncryptKey")) { - ipEncryptKey = boost::get((*vars)["ipEncryptKey"]); - } - } + getOptionalValue(vars, "serverID", serverID); + getOptionalValue(vars, "ipEncryptKey", ipEncryptKey); + checkAllParametersConsumed("RemoteLogAction", vars); return std::shared_ptr(new RemoteLogAction(logger, alterFunc, serverID, ipEncryptKey)); }); @@ -2488,14 +2479,9 @@ void setupLuaActions(LuaContext& luaCtx) std::string serverID; std::string ipEncryptKey; - if (vars) { - if (vars->count("serverID")) { - serverID = boost::get((*vars)["serverID"]); - } - if (vars->count("ipEncryptKey")) { - ipEncryptKey = boost::get((*vars)["ipEncryptKey"]); - } - } + getOptionalValue(vars, "serverID", serverID); + getOptionalValue(vars, "ipEncryptKey", ipEncryptKey); + checkAllParametersConsumed("RemoteLogResponseAction", vars); return std::shared_ptr(new RemoteLogResponseAction(logger, alterFunc, serverID, ipEncryptKey, includeCNAME ? *includeCNAME : false)); }); @@ -2564,6 +2550,7 @@ void setupLuaActions(LuaContext& luaCtx) auto ret = std::shared_ptr(new HTTPStatusAction(status, PacketBuffer(body.begin(), body.end()), contentType ? *contentType : "")); auto hsa = std::dynamic_pointer_cast(ret); parseResponseConfig(vars, hsa->d_responseConfig); + checkAllParametersConsumed("HTTPStatusAction", vars); return ret; }); #endif /* HAVE_DNS_OVER_HTTPS */ @@ -2580,17 +2567,13 @@ void setupLuaActions(LuaContext& luaCtx) luaCtx.writeFunction("NegativeAndSOAAction", [](bool nxd, const std::string& zone, uint32_t ttl, const std::string& mname, const std::string& rname, uint32_t serial, uint32_t refresh, uint32_t retry, uint32_t expire, uint32_t minimum, boost::optional vars) { bool soaInAuthoritySection = false; - if (vars) { - if (vars->count("soaInAuthoritySection")) { - soaInAuthoritySection = boost::get((*vars)["soaInAuthoritySection"]); - } - } - + getOptionalValue(vars, "soaInAuthoritySection", soaInAuthoritySection); auto ret = std::shared_ptr(new NegativeAndSOAAction(nxd, DNSName(zone), ttl, DNSName(mname), DNSName(rname), serial, refresh, retry, expire, minimum, soaInAuthoritySection)); auto action = std::dynamic_pointer_cast(ret); parseResponseConfig(vars, action->d_responseConfig); + checkAllParametersConsumed("NegativeAndSOAAction", vars); return ret; - }); + }); luaCtx.writeFunction("SetProxyProtocolValuesAction", [](const std::vector>& values) { return std::shared_ptr(new SetProxyProtocolValuesAction(values)); diff --git a/pdns/dnsdist-lua-rules.cc b/pdns/dnsdist-lua-rules.cc index 0506492e00..67740f9e03 100644 --- a/pdns/dnsdist-lua-rules.cc +++ b/pdns/dnsdist-lua-rules.cc @@ -73,14 +73,8 @@ void parseRuleParams(boost::optional params, boost::uuids::uuid string uuidStr; - if (params) { - if (params->count("uuid")) { - uuidStr = params->at("uuid"); - } - if (params->count("name")) { - name = params->at("name"); - } - } + getOptionalValue(params, "uuid", uuidStr); + getOptionalValue(params, "name", name); uuid = makeRuleID(uuidStr); creationOrder = s_creationOrder++; @@ -96,14 +90,9 @@ static std::string rulesToString(const std::vector& rules, boost::optionalcount("showUUIDs")) { - showUUIDs = boost::get((*vars)["showUUIDs"]); - } - if (vars->count("truncateRuleWidth")) { - truncateRuleWidth = boost::get((*vars)["truncateRuleWidth"]); - } - } + getOptionalValue(vars, "showUUIDs", showUUIDs); + getOptionalValue(vars, "truncateRuleWidth", truncateRuleWidth); + checkAllParametersConsumed("rulesToString", vars); if (showUUIDs) { boost::format fmt("%-3d %-30s %-38s %9d %9d %-56s %s\n"); diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index 2df956334d..ba20ce302a 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -107,29 +107,19 @@ void resetLuaSideEffect() using localbind_t = LuaAssociativeTable, LuaArray, LuaAssociativeTable>>; -static void parseLocalBindVars(boost::optional vars, bool& reusePort, int& tcpFastOpenQueueSize, std::string& interface, std::set& cpus, int& tcpListenQueueSize, uint64_t& maxInFlightQueriesPerConnection, uint64_t& tcpMaxConcurrentConnections) +static void parseLocalBindVars(boost::optional& vars, bool& reusePort, int& tcpFastOpenQueueSize, std::string& interface, std::set& cpus, int& tcpListenQueueSize, uint64_t& maxInFlightQueriesPerConnection, uint64_t& tcpMaxConcurrentConnections) { if (vars) { - if (vars->count("reusePort")) { - reusePort = boost::get((*vars)["reusePort"]); - } - if (vars->count("tcpFastOpenQueueSize")) { - tcpFastOpenQueueSize = boost::get((*vars)["tcpFastOpenQueueSize"]); - } - if (vars->count("tcpListenQueueSize")) { - tcpListenQueueSize = boost::get((*vars)["tcpListenQueueSize"]); - } - if (vars->count("maxConcurrentTCPConnections")) { - tcpMaxConcurrentConnections = boost::get((*vars)["maxConcurrentTCPConnections"]); - } - if (vars->count("maxInFlight")) { - maxInFlightQueriesPerConnection = boost::get((*vars)["maxInFlight"]); - } - if (vars->count("interface")) { - interface = boost::get((*vars)["interface"]); - } - if (vars->count("cpus")) { - for (const auto& cpu : boost::get>((*vars)["cpus"])) { + LuaArray setCpus; + + getOptionalValue(vars, "reusePort", reusePort); + getOptionalValue(vars, "tcpFastOpenQueueSize", tcpFastOpenQueueSize); + getOptionalValue(vars, "tcpListenQueueSize", tcpListenQueueSize); + getOptionalValue(vars, "maxConcurrentTCPConnections", tcpMaxConcurrentConnections); + getOptionalValue(vars, "maxInFlight", maxInFlightQueriesPerConnection); + getOptionalValue(vars, "interface", interface); + if (getOptionalValue(vars, "cpus", setCpus) > 0) { + for (const auto& cpu : setCpus) { cpus.insert(cpu.second); } } @@ -181,65 +171,45 @@ static bool loadTLSCertificateAndKeys(const std::string& context, std::vector vars) +static void parseTLSConfig(TLSConfig& config, const std::string& context, boost::optional& vars) { - if (vars->count("ciphers")) { - config.d_ciphers = boost::get((*vars)["ciphers"]); - } - - if (vars->count("ciphersTLS13")) { - config.d_ciphers13 = boost::get((*vars)["ciphersTLS13"]); - } + getOptionalValue(vars, "ciphers", config.d_ciphers); + getOptionalValue(vars, "ciphersTLS13", config.d_ciphers13); #ifdef HAVE_LIBSSL - if (vars->count("minTLSVersion")) { - config.d_minTLSVersion = libssl_tls_version_from_string(boost::get((*vars)["minTLSVersion"])); - } + std::string minVersion; + if (getOptionalValue(vars, "minTLSVersion", minVersion) > 0) + config.d_minTLSVersion = libssl_tls_version_from_string(minVersion); +#else /* HAVE_LIBSSL */ + if (vars->erase("minTLSVersion") > 0) + warnlog("minTLSVersion has no effect with chosen TLS library"); #endif /* HAVE_LIBSSL */ - if (vars->count("ticketKeyFile")) { - config.d_ticketKeyFile = boost::get((*vars)["ticketKeyFile"]); - } - - if (vars->count("ticketsKeysRotationDelay")) { - config.d_ticketsKeyRotationDelay = boost::get((*vars)["ticketsKeysRotationDelay"]); - } - - if (vars->count("numberOfTicketsKeys")) { - config.d_numberOfTicketsKeys = boost::get((*vars)["numberOfTicketsKeys"]); - } - - if (vars->count("preferServerCiphers")) { - config.d_preferServerCiphers = boost::get((*vars)["preferServerCiphers"]); - } - - if (vars->count("sessionTimeout")) { - config.d_sessionTimeout = boost::get((*vars)["sessionTimeout"]); - } - - if (vars->count("sessionTickets")) { - config.d_enableTickets = boost::get((*vars)["sessionTickets"]); + getOptionalValue(vars, "ticketKeyFile", config.d_ticketKeyFile); + getOptionalValue(vars, "ticketsKeysRotationDelay", config.d_ticketsKeyRotationDelay); + getOptionalValue(vars, "numberOfTicketsKeys", config.d_numberOfTicketsKeys); + getOptionalValue(vars, "preferServerCiphers", config.d_preferServerCiphers); + getOptionalValue(vars, "sessionTimeout", config.d_sessionTimeout); + getOptionalValue(vars, "sessionTickets", config.d_enableTickets); + int numberOfStoredSessions{0}; + if (getOptionalValue(vars, "numberOfStoredSessions", numberOfStoredSessions) > 0) { + if (numberOfStoredSessions < 0) { + errlog("Invalid value '%d' for %s() parameter 'numberOfStoredSessions', should be >= 0, dismissing", numberOfStoredSessions, context); + g_outputBuffer = "Invalid value '" + std::to_string(numberOfStoredSessions) + "' for " + context + "() parameter 'numberOfStoredSessions', should be >= 0, dimissing"; + } + config.d_maxStoredSessions = numberOfStoredSessions; } - if (vars->count("numberOfStoredSessions")) { - auto value = boost::get((*vars)["numberOfStoredSessions"]); - if (value < 0) { - errlog("Invalid value '%d' for %s() parameter 'numberOfStoredSessions', should be >= 0, dismissing", value, context); - g_outputBuffer = "Invalid value '" + std::to_string(value) + "' for " + context + "() parameter 'numberOfStoredSessions', should be >= 0, dismissing"; - } - config.d_maxStoredSessions = value; - } - - if (vars->count("ocspResponses")) { - auto files = boost::get>((*vars)["ocspResponses"]); + LuaArray files; + if (getOptionalValue(vars, "ocspResponses", files) > 0) { for (const auto& file : files) { config.d_ocspFiles.push_back(file.second); } } - if (vars->count("keyLogFile")) { + if (vars->count("keyLogFile") > 0) { #ifdef HAVE_SSL_CTX_SET_KEYLOG_CALLBACK - config.d_keyLogFile = boost::get((*vars)["keyLogFile"]); + getOptionalValue(vars, "keyLogFile", config.d_keyLogFile); #else errlog("TLS Key logging has been enabled using the 'keyLogFile' parameter to %s(), but this version of OpenSSL does not support it", context); g_outputBuffer = "TLS Key logging has been enabled using the 'keyLogFile' parameter to " + context + "(), but this version of OpenSSL does not support it"; @@ -349,6 +319,8 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) serverAddressStr = boost::get(vars["address"]); } + // FIXME: Check vars for unknown keys, needs refactoring to move creation at the end */ + if (vars.count("source")) { /* handle source in the following forms: - v4 address ("192.0.2.1") @@ -374,7 +346,6 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) /* try to parse as interface name, or v4/v6@itf */ config.sourceItfName = source.substr(pos == std::string::npos ? 0 : pos + 1); unsigned int itfIdx = if_nametoindex(config.sourceItfName.c_str()); - if (itfIdx != 0) { if (pos == 0 || pos == std::string::npos) { /* "eth0" or "@eth0" */ @@ -848,6 +819,8 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) parseLocalBindVars(vars, reusePort, tcpFastOpenQueueSize, interface, cpus, tcpListenQueueSize, maxInFlightQueriesPerConn, tcpMaxConcurrentConnections); + checkAllParametersConsumed("setLocal", vars); + try { ComboAddress loc(addr, 53); for (auto it = g_frontends.begin(); it != g_frontends.end();) { @@ -897,6 +870,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) std::set cpus; parseLocalBindVars(vars, reusePort, tcpFastOpenQueueSize, interface, cpus, tcpListenQueueSize, maxInFlightQueriesPerConn, tcpMaxConcurrentConnections); + checkAllParametersConsumed("addLocal", vars); try { ComboAddress loc(addr, 53); @@ -992,11 +966,9 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) luaCtx.writeFunction("showServers", [](boost::optional vars) { setLuaNoSideEffect(); bool showUUIDs = false; - if (vars) { - if (vars->count("showUUIDs")) { - showUUIDs = boost::get((*vars)["showUUIDs"]); - } - } + getOptionalValue(vars, "showUUIDs", showUUIDs); + checkAllParametersConsumed("showServers", vars); + try { ostringstream ret; boost::format fmt; @@ -1150,8 +1122,16 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) hashPlaintextCredentials = boost::get(vars->at("hashPlaintextCredentials")); } - if (vars->count("password")) { - std::string password = boost::get(vars->at("password")); + std::string password; + std::string apiKey; + std::string acl; + LuaAssociativeTable headers; + bool statsRequireAuthentication{true}; + bool apiRequiresAuthentication{true}; + bool dashboardRequiresAuthentication{true}; + std::string maxConcurrentConnections; + + if (getOptionalValue(vars, "password", password) > 0) { auto holder = make_unique(std::move(password), hashPlaintextCredentials); if (!holder->wasHashed() && holder->isHashingAvailable()) { infolog("Passing a plain-text password via the 'password' parameter to 'setWebserverConfig()' is not advised, please consider generating a hashed one using 'hashPassword()' instead."); @@ -1160,8 +1140,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) setWebserverPassword(std::move(holder)); } - if (vars->count("apiKey")) { - std::string apiKey = boost::get(vars->at("apiKey")); + if (getOptionalValue(vars, "apiKey", apiKey) > 0) { auto holder = make_unique(std::move(apiKey), hashPlaintextCredentials); if (!holder->wasHashed() && holder->isHashingAvailable()) { infolog("Passing a plain-text API key via the 'apiKey' parameter to 'setWebserverConfig()' is not advised, please consider generating a hashed one using 'hashPassword()' instead."); @@ -1170,32 +1149,28 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) setWebserverAPIKey(std::move(holder)); } - if (vars->count("acl")) { - const std::string acl = boost::get(vars->at("acl")); - + if (getOptionalValue(vars, "acl", acl) > 0) { setWebserverACL(acl); } - if (vars->count("customHeaders")) { - const auto headers = boost::get>(vars->at("customHeaders")); - + if (getOptionalValue(vars, "customHeaders", headers) > 0) { setWebserverCustomHeaders(headers); } - if (vars->count("statsRequireAuthentication")) { - setWebserverStatsRequireAuthentication(boost::get(vars->at("statsRequireAuthentication"))); + if (getOptionalValue(vars, "statsRequireAuthentication", statsRequireAuthentication) > 0) { + setWebserverStatsRequireAuthentication(statsRequireAuthentication); } - if (vars->count("apiRequiresAuthentication")) { - setWebserverAPIRequiresAuthentication(boost::get(vars->at("apiRequiresAuthentication"))); + if (getOptionalValue(vars, "apiRequiresAuthentication", apiRequiresAuthentication) > 0) { + setWebserverAPIRequiresAuthentication(apiRequiresAuthentication); } - if (vars->count("dashboardRequiresAuthentication")) { - setWebserverDashboardRequiresAuthentication(boost::get(vars->at("dashboardRequiresAuthentication"))); + if (getOptionalValue(vars, "dashboardRequiresAuthentication", dashboardRequiresAuthentication) > 0) { + setWebserverDashboardRequiresAuthentication(dashboardRequiresAuthentication); } - if (vars->count("maxConcurrentConnections")) { - setWebserverMaxConcurrentConnections(std::stoi(boost::get(vars->at("maxConcurrentConnections")))); + if (getOptionalValue(vars, "maxConcurrentConnections", maxConcurrentConnections) > 0) { + setWebserverMaxConcurrentConnections(std::stoi(maxConcurrentConnections)); } }); @@ -1675,6 +1650,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) std::vector certKeys; parseLocalBindVars(vars, reusePort, tcpFastOpenQueueSize, interface, cpus, tcpListenQueueSize, maxInFlightQueriesPerConn, tcpMaxConcurrentConnections); + checkAllParametersConsumed("addDNSCryptBind", vars); if (certFiles.type() == typeid(std::string) && keyFiles.type() == typeid(std::string)) { auto certFile = boost::get(certFiles); @@ -2524,43 +2500,25 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) if (vars) { parseLocalBindVars(vars, reusePort, tcpFastOpenQueueSize, interface, cpus, tcpListenQueueSize, maxInFlightQueriesPerConn, tcpMaxConcurrentConnections); - - if (vars->count("idleTimeout")) { - frontend->d_idleTimeout = boost::get((*vars)["idleTimeout"]); - } - - if (vars->count("serverTokens")) { - frontend->d_serverTokens = boost::get((*vars)["serverTokens"]); - } - - if (vars->count("customResponseHeaders")) { - for (auto const& headerMap : boost::get>((*vars).at("customResponseHeaders"))) { - frontend->d_customResponseHeaders[boost::to_lower_copy(headerMap.first)] = headerMap.second; + getOptionalValue(vars, "idleTimeout", frontend->d_idleTimeout); + getOptionalValue(vars, "serverTokens", frontend->d_serverTokens); + + LuaAssociativeTable customResponseHeaders; + if (getOptionalValue(vars, "customResponseHeaders", customResponseHeaders) > 0) { + for (auto const& headerMap : customResponseHeaders) { + std::pair headerResponse = std::make_pair(boost::to_lower_copy(headerMap.first), headerMap.second); + frontend->d_customResponseHeaders.insert(headerResponse); } } - if (vars->count("sendCacheControlHeaders")) { - frontend->d_sendCacheControlHeaders = boost::get((*vars)["sendCacheControlHeaders"]); - } - - if (vars->count("keepIncomingHeaders")) { - frontend->d_keepIncomingHeaders = boost::get((*vars)["keepIncomingHeaders"]); - } - - if (vars->count("trustForwardedForHeader")) { - frontend->d_trustForwardedForHeader = boost::get((*vars)["trustForwardedForHeader"]); - } + getOptionalValue(vars, "sendCacheControlHeaders", frontend->d_sendCacheControlHeaders); + getOptionalValue(vars, "keepIncomingHeaders", frontend->d_keepIncomingHeaders); + getOptionalValue(vars, "trustForwardedForHeader", frontend->d_trustForwardedForHeader); + getOptionalValue(vars, "internalPipeBufferSize", frontend->d_internalPipeBufferSize); + getOptionalValue(vars, "exactPathMatching", frontend->d_exactPathMatching); - if (vars->count("internalPipeBufferSize")) { - frontend->d_internalPipeBufferSize = boost::get((*vars)["internalPipeBufferSize"]); - } - - if (vars->count("exactPathMatching")) { - frontend->d_exactPathMatching = boost::get((*vars)["exactPathMatching"]); - } - - if (vars->count("additionalAddresses")) { - auto addresses = boost::get>(vars->at("additionalAddresses")); + LuaArray addresses; + if (getOptionalValue(vars, "additionalAddresses", addresses) > 0) { for (const auto& [_, add] : addresses) { try { ComboAddress address(add); @@ -2574,20 +2532,22 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } parseTLSConfig(frontend->d_tlsConfig, "addDOHLocal", vars); - if (vars->count("ignoreTLSConfigurationErrors")) { - if (boost::get((*vars)["ignoreTLSConfigurationErrors"])) { - // we are asked to try to load the certificates so we can return a potential error - // and properly ignore the frontend before actually launching it - try { - std::map ocspResponses = {}; - auto ctx = libssl_init_server_context(frontend->d_tlsConfig, ocspResponses); - } - catch (const std::runtime_error& e) { - errlog("Ignoring DoH frontend: '%s'", e.what()); - return; - } + + bool ignoreTLSConfigurationErrors = false; + if (getOptionalValue(vars, "ignoreTLSConfigurationErrors", ignoreTLSConfigurationErrors) > 0 && ignoreTLSConfigurationErrors) { + // we are asked to try to load the certificates so we can return a potential error + // and properly ignore the frontend before actually launching it + try { + std::map ocspResponses = {}; + auto ctx = libssl_init_server_context(frontend->d_tlsConfig, ocspResponses); + } + catch (const std::runtime_error& e) { + errlog("Ignoring DoH frontend: '%s'", e.what()); + return; } } + + checkAllParametersConsumed("addDOHLocal", vars); } g_dohlocals.push_back(frontend); auto cs = std::make_unique(frontend->d_local, true, reusePort, tcpFastOpenQueueSize, interface, cpus); @@ -2735,13 +2695,6 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } }); - luaCtx.registerFunction::*)() const>("getAddressAndPort", [](const std::shared_ptr& frontend) { - if (frontend == nullptr) { - return std::string(); - } - return frontend->d_local.toStringWithPort(); - }); - luaCtx.writeFunction("addTLSLocal", [client](const std::string& addr, boost::variant, LuaArray, LuaArray>> certFiles, LuaTypeOrArrayOf keyFiles, boost::optional vars) { if (client) { return; @@ -2770,13 +2723,11 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) if (vars) { parseLocalBindVars(vars, reusePort, tcpFastOpenQueueSize, interface, cpus, tcpListenQueueSize, maxInFlightQueriesPerConn, tcpMaxConcurrentConns); - if (vars->count("provider")) { - frontend->d_provider = boost::get((*vars)["provider"]); - boost::algorithm::to_lower(frontend->d_provider); - } + getOptionalValue(vars, "provider", frontend->d_provider); + boost::algorithm::to_lower(frontend->d_provider); - if (vars->count("additionalAddresses")) { - auto addresses = boost::get>(vars->at("additionalAddresses")); + LuaArray addresses; + if (getOptionalValue(vars, "additionalAddresses", addresses) > 0) { for (const auto& [_, add] : addresses) { try { ComboAddress address(add); @@ -2790,20 +2741,22 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } parseTLSConfig(frontend->d_tlsConfig, "addTLSLocal", vars); - if (vars->count("ignoreTLSConfigurationErrors")) { - if (boost::get((*vars)["ignoreTLSConfigurationErrors"])) { - // we are asked to try to load the certificates so we can return a potential error - // and properly ignore the frontend before actually launching it - try { - std::map ocspResponses = {}; - auto ctx = libssl_init_server_context(frontend->d_tlsConfig, ocspResponses); - } - catch (const std::runtime_error& e) { - errlog("Ignoring TLS frontend: '%s'", e.what()); - return; - } + + bool ignoreTLSConfigurationErrors = false; + if (getOptionalValue(vars, "ignoreTLSConfigurationErrors", ignoreTLSConfigurationErrors) > 0 && ignoreTLSConfigurationErrors) { + // we are asked to try to load the certificates so we can return a potential error + // and properly ignore the frontend before actually launching it + try { + std::map ocspResponses = {}; + auto ctx = libssl_init_server_context(frontend->d_tlsConfig, ocspResponses); + } + catch (const std::runtime_error& e) { + errlog("Ignoring TLS frontend: '%s'", e.what()); + return; } } + + checkAllParametersConsumed("addTLSLocal", vars); } try { diff --git a/pdns/dnsdist-lua.hh b/pdns/dnsdist-lua.hh index 9570f3e802..a477bd5e06 100644 --- a/pdns/dnsdist-lua.hh +++ b/pdns/dnsdist-lua.hh @@ -21,6 +21,7 @@ */ #pragma once +#include "dolog.hh" #include "dnsdist.hh" #include "dnsparser.hh" #include @@ -179,3 +180,40 @@ void setupLuaInspection(LuaContext& luaCtx); void setupLuaVars(LuaContext& luaCtx); void setupLuaWeb(LuaContext& luaCtx); void setupLuaLoadBalancingContext(LuaContext& luaCtx); + +/** + * getOptionalValue(vars, key, value) + * + * Attempts to extract value for key in vars. + * Erases the key from vars. + * + * returns: -1 if type wasn't compatible, 0 if not found or number of element(s) found + */ +template +static inline int getOptionalValue(boost::optional& vars, const std::string& key, T& value) { + /* nothing found, nothing to return */ + if (!vars) { + return 0; + } + + if (vars->count(key)) { + try { + value = boost::get((*vars)[key]); + } catch (const boost::bad_get& e) { + /* key is there but isn't compatible */ + return -1; + } + } + return vars->erase(key); +} + +template +static inline void checkAllParametersConsumed(const std::string& func, const boost::optional& vars) { + /* no vars */ + if (!vars) { + return; + } + for (const auto& [key, value] : *vars) { + warnlog("%s: Unknown key '%s' given - ignored", func, key); + } +} diff --git a/pdns/dnsdistdist/dnsdist-lua-bindings-packetcache.cc b/pdns/dnsdistdist/dnsdist-lua-bindings-packetcache.cc index 184ff33c28..fd62eb5318 100644 --- a/pdns/dnsdistdist/dnsdist-lua-bindings-packetcache.cc +++ b/pdns/dnsdistdist/dnsdist-lua-bindings-packetcache.cc @@ -24,7 +24,6 @@ #include #include "config.h" -#include "dolog.hh" #include "dnsdist.hh" #include "dnsdist-lua.hh" @@ -35,96 +34,62 @@ void setupLuaBindingsPacketCache(LuaContext& luaCtx, bool client) /* PacketCache */ luaCtx.writeFunction("newPacketCache", [client](size_t maxEntries, boost::optional>>> vars) { - bool keepStaleData = false; - size_t maxTTL = 86400; - size_t minTTL = 0; - size_t tempFailTTL = 60; - size_t maxNegativeTTL = 3600; - size_t staleTTL = 60; - size_t numberOfShards = 20; - bool dontAge = false; - bool deferrableInsertLock = true; - bool ecsParsing = false; - std::unordered_set optionsToSkip{EDNSOptionCode::COOKIE}; - - if (vars) { - - if (vars->count("deferrableInsertLock")) { - deferrableInsertLock = boost::get((*vars)["deferrableInsertLock"]); - } - - if (vars->count("dontAge")) { - dontAge = boost::get((*vars)["dontAge"]); - } - - if (vars->count("keepStaleData")) { - keepStaleData = boost::get((*vars)["keepStaleData"]); - } - - if (vars->count("maxNegativeTTL")) { - maxNegativeTTL = boost::get((*vars)["maxNegativeTTL"]); - } - - if (vars->count("maxTTL")) { - maxTTL = boost::get((*vars)["maxTTL"]); - } - - if (vars->count("minTTL")) { - minTTL = boost::get((*vars)["minTTL"]); - } - - if (vars->count("numberOfShards")) { - numberOfShards = boost::get((*vars)["numberOfShards"]); - } - - if (vars->count("parseECS")) { - ecsParsing = boost::get((*vars)["parseECS"]); - } - - if (vars->count("staleTTL")) { - staleTTL = boost::get((*vars)["staleTTL"]); - } - - if (vars->count("temporaryFailureTTL")) { - tempFailTTL = boost::get((*vars)["temporaryFailureTTL"]); - } - - if (vars->count("cookieHashing")) { - if (boost::get((*vars)["cookieHashing"])) { - optionsToSkip.erase(EDNSOptionCode::COOKIE); - } - } - if (vars->count("skipOptions")) { - for (const auto& option: boost::get>(vars->at("skipOptions"))) { - optionsToSkip.insert(option.second); - } - } + bool keepStaleData = false; + size_t maxTTL = 86400; + size_t minTTL = 0; + size_t tempFailTTL = 60; + size_t maxNegativeTTL = 3600; + size_t staleTTL = 60; + size_t numberOfShards = 20; + bool dontAge = false; + bool deferrableInsertLock = true; + bool ecsParsing = false; + bool cookieHashing = false; + LuaArray skipOptions; + std::unordered_set optionsToSkip{EDNSOptionCode::COOKIE}; + + getOptionalValue(vars, "deferrableInsertLock", deferrableInsertLock); + getOptionalValue(vars, "dontAge", dontAge); + getOptionalValue(vars, "keepStaleData", keepStaleData); + getOptionalValue(vars, "maxNegativeTTL", maxNegativeTTL); + getOptionalValue(vars, "maxTTL", maxTTL); + getOptionalValue(vars, "minTTL", minTTL); + getOptionalValue(vars, "numberOfShards", numberOfShards); + getOptionalValue(vars, "parseECS", ecsParsing); + getOptionalValue(vars, "staleTTL", staleTTL); + getOptionalValue(vars, "temporaryFailureTTL", tempFailTTL); + getOptionalValue(vars, "cookieHashing", cookieHashing); + + if (getOptionalValue(vars, "skipOptions", skipOptions) > 0) { + for (const auto& option : skipOptions) { + optionsToSkip.insert(option.second); } + } - if (maxEntries == 0) { - warnlog("The number of entries in the packet cache is set to 0, raising to 1"); - g_outputBuffer += "The number of entries in the packet cache is set to 0, raising to 1"; - maxEntries = 1; - } + if (cookieHashing) { + optionsToSkip.erase(EDNSOptionCode::COOKIE); + } - if (maxEntries < numberOfShards) { - warnlog("The number of entries (%d) in the packet cache is smaller than the number of shards (%d), decreasing the number of shards to %d", maxEntries, numberOfShards, maxEntries); - g_outputBuffer += "The number of entries (" + std::to_string(maxEntries) + " in the packet cache is smaller than the number of shards (" + std::to_string(numberOfShards) + "), decreasing the number of shards to " + std::to_string(maxEntries); - numberOfShards = maxEntries; - } + checkAllParametersConsumed("newPacketCache", vars); - if (client) { - maxEntries = 1; - numberOfShards = 1; - } + if (maxEntries < numberOfShards) { + warnlog("The number of entries (%d) in the packet cache is smaller than the number of shards (%d), decreasing the number of shards to %d", maxEntries, numberOfShards, maxEntries); + g_outputBuffer += "The number of entries (" + std::to_string(maxEntries) + " in the packet cache is smaller than the number of shards (" + std::to_string(numberOfShards) + "), decreasing the number of shards to " + std::to_string(maxEntries); + numberOfShards = maxEntries; + } - auto res = std::make_shared(maxEntries, maxTTL, minTTL, tempFailTTL, maxNegativeTTL, staleTTL, dontAge, numberOfShards, deferrableInsertLock, ecsParsing); + if (client) { + maxEntries = 1; + numberOfShards = 1; + } - res->setKeepStaleData(keepStaleData); - res->setSkippedOptions(optionsToSkip); + auto res = std::make_shared(maxEntries, maxTTL, minTTL, tempFailTTL, maxNegativeTTL, staleTTL, dontAge, numberOfShards, deferrableInsertLock, ecsParsing); - return res; - }); + res->setKeepStaleData(keepStaleData); + res->setSkippedOptions(optionsToSkip); + + return res; + }); #ifndef DISABLE_PACKETCACHE_BINDINGS luaCtx.registerFunction::*)()const>("toString", [](const std::shared_ptr& cache) { diff --git a/pdns/dnsdistdist/dnsdist-lua-bindings-protobuf.cc b/pdns/dnsdistdist/dnsdist-lua-bindings-protobuf.cc index 6e470d34e8..e532a56828 100644 --- a/pdns/dnsdistdist/dnsdist-lua-bindings-protobuf.cc +++ b/pdns/dnsdistdist/dnsdist-lua-bindings-protobuf.cc @@ -32,7 +32,7 @@ #include "remote_logger.hh" #ifdef HAVE_FSTRM -static void parseFSTRMOptions(const boost::optional>& params, LuaAssociativeTable& options) +static void parseFSTRMOptions(boost::optional>& params, LuaAssociativeTable& options) { if (!params) { return; @@ -41,9 +41,7 @@ static void parseFSTRMOptions(const boost::optional const potentialOptions = { "bufferHint", "flushTimeout", "inputQueueSize", "outputQueueSize", "queueNotifyThreshold", "reopenInterval" }; for (const auto& potentialOption : potentialOptions) { - if (params->count(potentialOption)) { - options[potentialOption] = boost::get(params->at(potentialOption)); - } + getOptionalValue(params, potentialOption, options[potentialOption]); } } #endif /* HAVE_FSTRM */ @@ -138,6 +136,7 @@ void setupLuaBindingsProtoBuf(LuaContext& luaCtx, bool client, bool configCheck) LuaAssociativeTable options; parseFSTRMOptions(params, options); + checkAllParametersConsumed("newRemoteLogger", params); return std::shared_ptr(new FrameStreamLogger(AF_UNIX, address, !client, options)); #else throw std::runtime_error("fstrm support is required to build an AF_UNIX FrameStreamLogger"); @@ -152,6 +151,7 @@ void setupLuaBindingsProtoBuf(LuaContext& luaCtx, bool client, bool configCheck) LuaAssociativeTable options; parseFSTRMOptions(params, options); + checkAllParametersConsumed("newFrameStreamTcpLogger", params); return std::shared_ptr(new FrameStreamLogger(AF_INET, address, !client, options)); #else throw std::runtime_error("fstrm with TCP support is required to build an AF_INET FrameStreamLogger");