From: Remi Gacogne Date: Tue, 7 Feb 2023 16:04:24 +0000 (+0100) Subject: dnsdist: Better handling of failed numerical conversions X-Git-Tag: dnsdist-1.8.0-rc1~43^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=refs%2Fpull%2F10115%2Fhead;p=thirdparty%2Fpdns.git dnsdist: Better handling of failed numerical conversions As suggested by Charles-Henri Bruyand (thanks!). --- diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index 63273c2bf2..4c517bc0aa 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -178,8 +178,9 @@ static void parseTLSConfig(TLSConfig& config, const std::string& context, boost: #ifdef HAVE_LIBSSL std::string minVersion; - if (getOptionalValue(vars, "minTLSVersion", minVersion) > 0) + if (getOptionalValue(vars, "minTLSVersion", minVersion) > 0) { config.d_minTLSVersion = libssl_tls_version_from_string(minVersion); + } #else /* HAVE_LIBSSL */ if (vars->erase("minTLSVersion") > 0) warnlog("minTLSVersion has no effect with chosen TLS library"); @@ -218,17 +219,9 @@ static void parseTLSConfig(TLSConfig& config, const std::string& context, boost: #endif } - if (vars->count("releaseBuffers")) { - config.d_releaseBuffers = boost::get((*vars)["releaseBuffers"]); - } - - if (vars->count("enableRenegotiation")) { - config.d_enableRenegotiation = boost::get((*vars)["enableRenegotiation"]); - } - - if (vars->count("tlsAsyncMode")) { - config.d_asyncMode = boost::get((*vars).at("tlsAsyncMode")); - } + getOptionalValue(vars, "releaseBuffers", config.d_releaseBuffers); + getOptionalValue(vars, "enableRenegotiation", config.d_enableRenegotiation); + getOptionalValue(vars, "tlsAsyncMode", config.d_asyncMode); } #endif // defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS) @@ -376,50 +369,23 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } } - if (getOptionalValue(vars, "qps", valueStr) > 0) { - config.d_qpsLimit = std::stoi(valueStr); - } - - if (getOptionalValue(vars, "order", valueStr) > 0) { - config.order = std::stoi(valueStr); - } - - if (getOptionalValue(vars, "weight", valueStr) > 0) { - try { - config.d_weight = std::stoi(valueStr); - - if (config.d_weight < 1) { - errlog("Error creating new server: downstream weight value must be greater than 0."); - return std::shared_ptr(); - } - } - catch (const std::exception& e) { - // std::stoi will throw an exception if the string isn't in a value int range - errlog("Error creating new server: downstream weight value must be between %s and %s", 1, std::numeric_limits::max()); - return std::shared_ptr(); - } + getOptionalIntegerValue("newServer", vars, "qps", config.d_qpsLimit); + getOptionalIntegerValue("newServer", vars, "order", config.order); + getOptionalIntegerValue("newServer", vars, "weight", config.d_weight); + if (config.d_weight < 1) { + errlog("Error creating new server: downstream weight value must be greater than 0."); + return std::shared_ptr(); } - if (getOptionalValue(vars, "retries", valueStr) > 0) { - config.d_retries = std::stoi(valueStr); - } + getOptionalIntegerValue("newServer", vars, "retries", config.d_retries); + getOptionalIntegerValue("newServer", vars, "tcpConnectTimeout", config.tcpConnectTimeout); + getOptionalIntegerValue("newServer", vars, "tcpSendTimeout", config.tcpSendTimeout); + getOptionalIntegerValue("newServer", vars, "tcpRecvTimeout", config.tcpRecvTimeout); if (getOptionalValue(vars, "checkInterval", valueStr) > 0) { config.checkInterval = static_cast(std::stoul(valueStr)); } - if (getOptionalValue(vars, "tcpConnectTimeout", valueStr) > 0) { - config.tcpConnectTimeout = std::stoi(boost::get(valueStr)); - } - - if (getOptionalValue(vars, "tcpSendTimeout", valueStr) > 0) { - config.tcpSendTimeout = std::stoi(valueStr); - } - - if (getOptionalValue(vars, "tcpRecvTimeout", valueStr) > 0) { - config.tcpRecvTimeout = std::stoi(valueStr); - } - bool fastOpen{false}; if (getOptionalValue(vars, "tcpFastOpen", fastOpen) > 0) { if (fastOpen) { @@ -431,17 +397,10 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } } - if (getOptionalValue(vars, "maxInFlight", valueStr) > 0) { - config.d_maxInFlightQueriesPerConn = std::stoi(valueStr); - } + getOptionalIntegerValue("newServer", vars, "maxInFlight", config.d_maxInFlightQueriesPerConn); + getOptionalIntegerValue("newServer", vars, "maxConcurrentTCPConnections", config.d_tcpConcurrentConnectionsLimit); - if (getOptionalValue(vars, "maxConcurrentTCPConnections", valueStr) > 0) { - config.d_tcpConcurrentConnectionsLimit = std::stoi(valueStr); - } - - if (getOptionalValue(vars, "name", valueStr) > 0) { - config.name = valueStr; - } + getOptionalValue(vars, "name", config.name); if (getOptionalValue(vars, "id", valueStr) > 0) { config.id = boost::uuids::string_generator()(valueStr); @@ -471,17 +430,9 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } getOptionalValue(vars, "checkType", config.checkType); - - if (getOptionalValue(vars, "checkClass", valueStr) > 0) { - config.checkClass = std::stoi(valueStr); - } - + getOptionalIntegerValue("newServer", vars, "checkClass", config.checkClass); getOptionalValue(vars, "checkFunction", config.checkFunction); - - if (getOptionalValue(vars, "checkTimeout", valueStr) > 0) { - config.checkTimeout = std::stoi(valueStr); - } - + getOptionalIntegerValue("newServer", vars, "checkTimeout", config.checkTimeout); getOptionalValue(vars, "checkTCP", config.d_tcpCheck); getOptionalValue(vars, "setCD", config.setCD); getOptionalValue(vars, "mustResolve", config.mustResolve); @@ -538,32 +489,9 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) getOptionalValue(vars, "disableZeroScoping", config.disableZeroScope); getOptionalValue(vars, "ipBindAddrNoPort", config.ipBindAddrNoPort); - if (getOptionalValue(vars, "addXPF", valueStr) > 0) { - try { - config.xpfRRCode = std::stoi(valueStr); - } - catch (const std::exception& e) { - warnlog("addXPF must be integer, not '%s' - ignoring", valueStr); - } - } - - if (getOptionalValue(vars, "maxCheckFailures", valueStr) > 0) { - try { - config.maxCheckFailures = std::stoi(valueStr); - } - catch (const std::exception& e) { - warnlog("maxCheckFailures must be integer, not '%s' - ignoring", valueStr); - } - } - - if (getOptionalValue(vars, "rise", valueStr) > 0) { - try { - config.minRiseSuccesses = std::stoi(valueStr); - } - catch (const std::exception& e) { - warnlog("rise must be integer, not '%s' - ignoring", valueStr); - } - } + getOptionalIntegerValue("newServer", vars, "addXPF", config.xpfRRCode); + getOptionalIntegerValue("newServer", vars, "maxCheckFailures", config.maxCheckFailures); + getOptionalIntegerValue("newServer", vars, "rise", config.minRiseSuccesses); getOptionalValue(vars, "reconnectOnUp", config.reconnectOnUp); @@ -637,7 +565,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } LuaArray pools; - if (getOptionalValue(vars, "pool", valueStr) > 0) { + if (getOptionalValue(vars, "pool", valueStr, false) > 0) { config.pools.insert(valueStr); } else if (getOptionalValue(vars, "pool", pools) > 0) { @@ -1083,9 +1011,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } bool hashPlaintextCredentials = false; - if (vars->count("hashPlaintextCredentials")) { - hashPlaintextCredentials = boost::get(vars->at("hashPlaintextCredentials")); - } + getOptionalValue(vars, "hashPlaintextCredentials", hashPlaintextCredentials); std::string password; std::string apiKey; @@ -1094,7 +1020,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) bool statsRequireAuthentication{true}; bool apiRequiresAuthentication{true}; bool dashboardRequiresAuthentication{true}; - std::string maxConcurrentConnections; + int maxConcurrentConnections = 0; if (getOptionalValue(vars, "password", password) > 0) { auto holder = make_unique(std::move(password), hashPlaintextCredentials); @@ -1134,8 +1060,8 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) setWebserverDashboardRequiresAuthentication(dashboardRequiresAuthentication); } - if (getOptionalValue(vars, "maxConcurrentConnections", maxConcurrentConnections) > 0) { - setWebserverMaxConcurrentConnections(std::stoi(maxConcurrentConnections)); + if (getOptionalIntegerValue("setWebserverConfig", vars, "maxConcurrentConnections", maxConcurrentConnections) > 0) { + setWebserverMaxConcurrentConnections(maxConcurrentConnections); } }); diff --git a/pdns/dnsdist-lua.hh b/pdns/dnsdist-lua.hh index a477bd5e06..8b64db3ab9 100644 --- a/pdns/dnsdist-lua.hh +++ b/pdns/dnsdist-lua.hh @@ -190,7 +190,7 @@ void setupLuaLoadBalancingContext(LuaContext& luaCtx); * returns: -1 if type wasn't compatible, 0 if not found or number of element(s) found */ template -static inline int getOptionalValue(boost::optional& vars, const std::string& key, T& value) { +static inline int getOptionalValue(boost::optional& vars, const std::string& key, T& value, bool warnOnWrongType = true) { /* nothing found, nothing to return */ if (!vars) { return 0; @@ -201,12 +201,32 @@ static inline int getOptionalValue(boost::optional& vars, const std::string& value = boost::get((*vars)[key]); } catch (const boost::bad_get& e) { /* key is there but isn't compatible */ + if (warnOnWrongType) { + warnlog("Invalid type for key '%s' - ignored", key); + vars->erase(key); + } return -1; } } return vars->erase(key); } +template +static inline int getOptionalIntegerValue(const std::string& func, boost::optional& vars, const std::string& key, T& value) { + std::string valueStr; + auto ret = getOptionalValue(vars, key, valueStr, true); + if (ret == 1) { + try { + value = std::stoi(valueStr); + } + catch (const std::exception& e) { + warnlog("Parameter '%s' of '%s' must be integer, not '%s' - ignoring", func, key, valueStr); + return -1; + } + } + return ret; +} + template static inline void checkAllParametersConsumed(const std::string& func, const boost::optional& vars) { /* no vars */ diff --git a/regression-tests.dnsdist/test_Routing.py b/regression-tests.dnsdist/test_Routing.py index fbe5a42320..9d45803fca 100644 --- a/regression-tests.dnsdist/test_Routing.py +++ b/regression-tests.dnsdist/test_Routing.py @@ -735,30 +735,3 @@ class TestRoutingHighValueWRandom(DNSDistTest): self.assertEqual(self._responsesCounter['UDP Responder 2'], numberOfQueries - self._responsesCounter['UDP Responder']) if 'TCP Responder 2' in self._responsesCounter: self.assertEqual(self._responsesCounter['TCP Responder 2'], numberOfQueries - self._responsesCounter['TCP Responder']) - -class TestRoutingBadWeightWRandom(DNSDistTest): - - _testServer2Port = 5351 - _consoleKey = DNSDistTest.generateConsoleKey() - _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii') - _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_testServer2Port'] - _config_template = """ - setKey("%s") - controlSocket("127.0.0.1:%s") - setServerPolicy(wrandom) - s1 = newServer{address="127.0.0.1:%s", weight=-1} - s2 = newServer{address="127.0.0.1:%s", weight=2147483648} - """ - _checkConfigExpectedOutput = b"""Error creating new server: downstream weight value must be greater than 0. -Error creating new server: downstream weight value must be between 1 and 2147483647 -Configuration 'configs/dnsdist_TestRoutingBadWeightWRandom.conf' OK! -""" - - def testBadWeightWRandom(self): - """ - Routing: WRandom - - Test that downstreams cannot be added with invalid weights. - """ - # There should be no downstreams - self.assertTrue(self.sendConsoleCommand("getServer(0)").startswith("Error"))