]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Better handling of failed numerical conversions 10115/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 7 Feb 2023 16:04:24 +0000 (17:04 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 7 Feb 2023 16:04:24 +0000 (17:04 +0100)
As suggested by Charles-Henri Bruyand (thanks!).

pdns/dnsdist-lua.cc
pdns/dnsdist-lua.hh
regression-tests.dnsdist/test_Routing.py

index 63273c2bf2c5e4bf3f403576e018cf230d81aeae..4c517bc0aa46172910d749d254b996dd678e730e 100644 (file)
@@ -178,8 +178,9 @@ static void parseTLSConfig(TLSConfig& config, const std::string& context, boost:
 
 #ifdef HAVE_LIBSSL
   std::string minVersion;
-  if (getOptionalValue<std::string>(vars, "minTLSVersion", minVersion) > 0)
+  if (getOptionalValue<std::string>(vars, "minTLSVersion", minVersion) > 0) {
     config.d_minTLSVersion = libssl_tls_version_from_string(minVersion);
+  }
 #else /* HAVE_LIBSSL */
   if (vars->erase("minTLSVersion") > 0)
     warnlog("minTLSVersion has no effect with chosen TLS library");
@@ -218,17 +219,9 @@ static void parseTLSConfig(TLSConfig& config, const std::string& context, boost:
 #endif
   }
 
-  if (vars->count("releaseBuffers")) {
-    config.d_releaseBuffers = boost::get<bool>((*vars)["releaseBuffers"]);
-  }
-
-  if (vars->count("enableRenegotiation")) {
-    config.d_enableRenegotiation = boost::get<bool>((*vars)["enableRenegotiation"]);
-  }
-
-  if (vars->count("tlsAsyncMode")) {
-    config.d_asyncMode = boost::get<bool>((*vars).at("tlsAsyncMode"));
-  }
+  getOptionalValue<bool>(vars, "releaseBuffers", config.d_releaseBuffers);
+  getOptionalValue<bool>(vars, "enableRenegotiation", config.d_enableRenegotiation);
+  getOptionalValue<bool>(vars, "tlsAsyncMode", config.d_asyncMode);
 }
 
 #endif // defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
@@ -376,50 +369,23 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
                            }
                          }
 
-                         if (getOptionalValue<std::string>(vars, "qps", valueStr) > 0) {
-                           config.d_qpsLimit = std::stoi(valueStr);
-                         }
-
-                         if (getOptionalValue<std::string>(vars, "order", valueStr) > 0) {
-                           config.order = std::stoi(valueStr);
-                         }
-
-                         if (getOptionalValue<std::string>(vars, "weight", valueStr) > 0) {
-                           try {
-                             config.d_weight = std::stoi(valueStr);
-
-                             if (config.d_weight < 1) {
-                               errlog("Error creating new server: downstream weight value must be greater than 0.");
-                               return std::shared_ptr<DownstreamState>();
-                             }
-                           }
-                           catch (const std::exception& e) {
-                             // std::stoi will throw an exception if the string isn't in a value int range
-                             errlog("Error creating new server: downstream weight value must be between %s and %s", 1, std::numeric_limits<int>::max());
-                             return std::shared_ptr<DownstreamState>();
-                           }
+                         getOptionalIntegerValue("newServer", vars, "qps", config.d_qpsLimit);
+                         getOptionalIntegerValue("newServer", vars, "order", config.order);
+                         getOptionalIntegerValue("newServer", vars, "weight", config.d_weight);
+                         if (config.d_weight < 1) {
+                           errlog("Error creating new server: downstream weight value must be greater than 0.");
+                           return std::shared_ptr<DownstreamState>();
                          }
 
-                         if (getOptionalValue<std::string>(vars, "retries", valueStr) > 0) {
-                           config.d_retries = std::stoi(valueStr);
-                         }
+                         getOptionalIntegerValue("newServer", vars, "retries", config.d_retries);
+                         getOptionalIntegerValue("newServer", vars, "tcpConnectTimeout", config.tcpConnectTimeout);
+                         getOptionalIntegerValue("newServer", vars, "tcpSendTimeout", config.tcpSendTimeout);
+                         getOptionalIntegerValue("newServer", vars, "tcpRecvTimeout", config.tcpRecvTimeout);
 
                          if (getOptionalValue<std::string>(vars, "checkInterval", valueStr) > 0) {
                            config.checkInterval = static_cast<unsigned int>(std::stoul(valueStr));
                          }
 
-                         if (getOptionalValue<std::string>(vars, "tcpConnectTimeout", valueStr) > 0) {
-                           config.tcpConnectTimeout = std::stoi(boost::get<string>(valueStr));
-                         }
-
-                         if (getOptionalValue<std::string>(vars, "tcpSendTimeout", valueStr) > 0) {
-                           config.tcpSendTimeout = std::stoi(valueStr);
-                         }
-
-                         if (getOptionalValue<std::string>(vars, "tcpRecvTimeout", valueStr) > 0) {
-                           config.tcpRecvTimeout = std::stoi(valueStr);
-                         }
-
                          bool fastOpen{false};
                          if (getOptionalValue<bool>(vars, "tcpFastOpen", fastOpen) > 0) {
                            if (fastOpen) {
@@ -431,17 +397,10 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
                            }
                          }
 
-                         if (getOptionalValue<std::string>(vars, "maxInFlight", valueStr) > 0) {
-                           config.d_maxInFlightQueriesPerConn = std::stoi(valueStr);
-                         }
+                         getOptionalIntegerValue("newServer", vars, "maxInFlight", config.d_maxInFlightQueriesPerConn);
+                         getOptionalIntegerValue("newServer", vars, "maxConcurrentTCPConnections", config.d_tcpConcurrentConnectionsLimit);
 
-                         if (getOptionalValue<std::string>(vars, "maxConcurrentTCPConnections", valueStr) > 0) {
-                           config.d_tcpConcurrentConnectionsLimit = std::stoi(valueStr);
-                         }
-
-                         if (getOptionalValue<std::string>(vars, "name", valueStr) > 0) {
-                           config.name = valueStr;
-                         }
+                         getOptionalValue<std::string>(vars, "name", config.name);
 
                          if (getOptionalValue<std::string>(vars, "id", valueStr) > 0) {
                            config.id = boost::uuids::string_generator()(valueStr);
@@ -471,17 +430,9 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
                          }
 
                          getOptionalValue<std::string>(vars, "checkType", config.checkType);
-
-                         if (getOptionalValue<std::string>(vars, "checkClass", valueStr) > 0) {
-                           config.checkClass = std::stoi(valueStr);
-                         }
-
+                         getOptionalIntegerValue("newServer", vars, "checkClass", config.checkClass);
                          getOptionalValue<DownstreamState::checkfunc_t>(vars, "checkFunction", config.checkFunction);
-
-                         if (getOptionalValue<std::string>(vars, "checkTimeout", valueStr) > 0) {
-                           config.checkTimeout = std::stoi(valueStr);
-                         }
-
+                         getOptionalIntegerValue("newServer", vars, "checkTimeout", config.checkTimeout);
                          getOptionalValue<bool>(vars, "checkTCP", config.d_tcpCheck);
                          getOptionalValue<bool>(vars, "setCD", config.setCD);
                          getOptionalValue<bool>(vars, "mustResolve", config.mustResolve);
@@ -538,32 +489,9 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
                          getOptionalValue<bool>(vars, "disableZeroScoping", config.disableZeroScope);
                          getOptionalValue<bool>(vars, "ipBindAddrNoPort", config.ipBindAddrNoPort);
 
-                         if (getOptionalValue<std::string>(vars, "addXPF", valueStr) > 0) {
-                           try {
-                             config.xpfRRCode = std::stoi(valueStr);
-                           }
-                           catch (const std::exception& e) {
-                             warnlog("addXPF must be integer, not '%s' - ignoring", valueStr);
-                           }
-                         }
-
-                         if (getOptionalValue<std::string>(vars, "maxCheckFailures", valueStr) > 0) {
-                           try {
-                             config.maxCheckFailures = std::stoi(valueStr);
-                           }
-                           catch (const std::exception& e) {
-                             warnlog("maxCheckFailures must be integer, not '%s' - ignoring", valueStr);
-                           }
-                         }
-
-                         if (getOptionalValue<std::string>(vars, "rise", valueStr) > 0) {
-                           try {
-                             config.minRiseSuccesses = std::stoi(valueStr);
-                           }
-                           catch (const std::exception& e) {
-                             warnlog("rise must be integer, not '%s' - ignoring", valueStr);
-                           }
-                         }
+                         getOptionalIntegerValue("newServer", vars, "addXPF", config.xpfRRCode);
+                         getOptionalIntegerValue("newServer", vars, "maxCheckFailures", config.maxCheckFailures);
+                         getOptionalIntegerValue("newServer", vars, "rise", config.minRiseSuccesses);
 
                          getOptionalValue<bool>(vars, "reconnectOnUp", config.reconnectOnUp);
 
@@ -637,7 +565,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
                          }
 
                          LuaArray<std::string> pools;
-                         if (getOptionalValue<std::string>(vars, "pool", valueStr) > 0) {
+                         if (getOptionalValue<std::string>(vars, "pool", valueStr, false) > 0) {
                            config.pools.insert(valueStr);
                          }
                          else if (getOptionalValue<decltype(pools)>(vars, "pool", pools) > 0) {
@@ -1083,9 +1011,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     }
 
     bool hashPlaintextCredentials = false;
-    if (vars->count("hashPlaintextCredentials")) {
-      hashPlaintextCredentials = boost::get<bool>(vars->at("hashPlaintextCredentials"));
-    }
+    getOptionalValue<bool>(vars, "hashPlaintextCredentials", hashPlaintextCredentials);
 
     std::string password;
     std::string apiKey;
@@ -1094,7 +1020,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     bool statsRequireAuthentication{true};
     bool apiRequiresAuthentication{true};
     bool dashboardRequiresAuthentication{true};
-    std::string maxConcurrentConnections;
+    int maxConcurrentConnections = 0;
 
     if (getOptionalValue<std::string>(vars, "password", password) > 0) {
       auto holder = make_unique<CredentialsHolder>(std::move(password), hashPlaintextCredentials);
@@ -1134,8 +1060,8 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       setWebserverDashboardRequiresAuthentication(dashboardRequiresAuthentication);
     }
 
-    if (getOptionalValue<std::string>(vars, "maxConcurrentConnections", maxConcurrentConnections) > 0) {
-      setWebserverMaxConcurrentConnections(std::stoi(maxConcurrentConnections));
+    if (getOptionalIntegerValue("setWebserverConfig", vars, "maxConcurrentConnections", maxConcurrentConnections) > 0) {
+      setWebserverMaxConcurrentConnections(maxConcurrentConnections);
     }
   });
 
index a477bd5e06c34b929d8a1d99b12b2c0f7b11d4bf..8b64db3ab9df9c08b604d6d56dba66eec8ff12ea 100644 (file)
@@ -190,7 +190,7 @@ void setupLuaLoadBalancingContext(LuaContext& luaCtx);
  * returns: -1 if type wasn't compatible, 0 if not found or number of element(s) found
  */
 template<class G, class T, class V>
-static inline int getOptionalValue(boost::optional<V>& vars, const std::string& key, T& value) {
+static inline int getOptionalValue(boost::optional<V>& vars, const std::string& key, T& value, bool warnOnWrongType = true) {
   /* nothing found, nothing to return */
   if (!vars) {
     return 0;
@@ -201,12 +201,32 @@ static inline int getOptionalValue(boost::optional<V>& vars, const std::string&
       value = boost::get<G>((*vars)[key]);
     } catch (const boost::bad_get& e) {
       /* key is there but isn't compatible */
+      if (warnOnWrongType) {
+        warnlog("Invalid type for key '%s' - ignored", key);
+        vars->erase(key);
+      }
       return -1;
     }
   }
   return vars->erase(key);
 }
 
+template<class T, class V>
+static inline int getOptionalIntegerValue(const std::string& func, boost::optional<V>& vars, const std::string& key, T& value) {
+  std::string valueStr;
+  auto ret = getOptionalValue<std::string>(vars, key, valueStr, true);
+  if (ret == 1) {
+    try {
+      value = std::stoi(valueStr);
+    }
+    catch (const std::exception& e) {
+      warnlog("Parameter '%s' of '%s' must be integer, not '%s' - ignoring", func, key, valueStr);
+      return -1;
+    }
+  }
+  return ret;
+}
+
 template<class V>
 static inline void checkAllParametersConsumed(const std::string& func, const boost::optional<V>& vars) {
   /* no vars */
index fbe5a42320e817d08b3f98ae808e3c44b56b7bb3..9d45803fcabec7cbb9b88d12811de67f190ffc90 100644 (file)
@@ -735,30 +735,3 @@ class TestRoutingHighValueWRandom(DNSDistTest):
             self.assertEqual(self._responsesCounter['UDP Responder 2'], numberOfQueries - self._responsesCounter['UDP Responder'])
         if 'TCP Responder 2' in self._responsesCounter:
             self.assertEqual(self._responsesCounter['TCP Responder 2'], numberOfQueries - self._responsesCounter['TCP Responder'])
-
-class TestRoutingBadWeightWRandom(DNSDistTest):
-
-    _testServer2Port = 5351
-    _consoleKey = DNSDistTest.generateConsoleKey()
-    _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
-    _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_testServer2Port']
-    _config_template = """
-    setKey("%s")
-    controlSocket("127.0.0.1:%s")
-    setServerPolicy(wrandom)
-    s1 = newServer{address="127.0.0.1:%s", weight=-1}
-    s2 = newServer{address="127.0.0.1:%s", weight=2147483648}
-    """
-    _checkConfigExpectedOutput = b"""Error creating new server: downstream weight value must be greater than 0.
-Error creating new server: downstream weight value must be between 1 and 2147483647
-Configuration 'configs/dnsdist_TestRoutingBadWeightWRandom.conf' OK!
-"""
-
-    def testBadWeightWRandom(self):
-        """
-        Routing: WRandom
-
-        Test that downstreams cannot be added with invalid weights.
-        """
-        # There should be no downstreams
-        self.assertTrue(self.sendConsoleCommand("getServer(0)").startswith("Error"))