]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Warn on unsupported parameters
authorAki Tuomi <cmouse@cmouse.fi>
Mon, 22 Feb 2021 12:07:28 +0000 (14:07 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 26 Jan 2023 15:32:16 +0000 (16:32 +0100)
pdns/dnsdist-lua-actions.cc
pdns/dnsdist-lua-rules.cc
pdns/dnsdist-lua.cc
pdns/dnsdist-lua.hh
pdns/dnsdistdist/dnsdist-lua-bindings-packetcache.cc
pdns/dnsdistdist/dnsdist-lua-bindings-protobuf.cc

index fb0695c376e8e6ce14a84b65ef553c20e58edc01..245032fdfbb4173c0c7bece55a9aeac9da5b1c28 100644 (file)
@@ -31,7 +31,6 @@
 #include "dnsdist-kvs.hh"
 #include "dnsdist-svc.hh"
 
-#include "dolog.hh"
 #include "dnstap.hh"
 #include "dnswriter.hh"
 #include "ednsoptions.hh"
@@ -2100,6 +2099,7 @@ static void addAction(GlobalStateHolder<vector<T> > *someRuleActions, const luad
   boost::uuids::uuid uuid;
   uint64_t creationOrder;
   parseRuleParams(params, uuid, name, creationOrder);
+  checkAllParametersConsumed("addAction", params);
 
   auto rule = makeRule(var);
   someRuleActions->modify([&rule, &action, &uuid, creationOrder, &name](vector<T>& ruleactions){
@@ -2109,22 +2109,12 @@ static void addAction(GlobalStateHolder<vector<T> > *someRuleActions, const luad
 
 typedef std::unordered_map<std::string, boost::variant<bool, uint32_t> > responseParams_t;
 
-static void parseResponseConfig(boost::optional<responseParams_t> vars, ResponseConfig& config)
+static void parseResponseConfig(boost::optional<responseParams_t>& vars, ResponseConfig& config)
 {
-  if (vars) {
-    if (vars->count("ttl")) {
-      config.ttl = boost::get<uint32_t>((*vars)["ttl"]);
-    }
-    if (vars->count("aa")) {
-      config.setAA = boost::get<bool>((*vars)["aa"]);
-    }
-    if (vars->count("ad")) {
-      config.setAD = boost::get<bool>((*vars)["ad"]);
-    }
-    if (vars->count("ra")) {
-      config.setRA = boost::get<bool>((*vars)["ra"]);
-    }
-  }
+  getOptionalValue<uint32_t>(vars, "ttl", config.ttl);
+  getOptionalValue<bool>(vars, "aa", config.setAA);
+  getOptionalValue<bool>(vars, "ad", config.setAD);
+  getOptionalValue<bool>(vars, "ra", config.setRA);
 }
 
 void setResponseHeadersFromConfig(dnsheader& dh, const ResponseConfig& config)
@@ -2153,6 +2143,7 @@ void setupLuaActions(LuaContext& luaCtx)
       uint64_t creationOrder;
       std::string name;
       parseRuleParams(params, uuid, name, creationOrder);
+      checkAllParametersConsumed("newRuleAction", params);
 
       auto rule = makeRule(dnsrule);
       DNSDistRuleAction ra({std::move(rule), action, std::move(name), uuid, creationOrder});
@@ -2277,6 +2268,7 @@ void setupLuaActions(LuaContext& luaCtx)
       auto ret = std::shared_ptr<DNSAction>(new SpoofAction(addrs));
       auto sa = std::dynamic_pointer_cast<SpoofAction>(ret);
       parseResponseConfig(vars, sa->d_responseConfig);
+      checkAllParametersConsumed("SpoofAction", vars);
       return ret;
     });
 
@@ -2291,6 +2283,7 @@ void setupLuaActions(LuaContext& luaCtx)
       auto ret = std::shared_ptr<DNSAction>(new SpoofAction(DNSName(a)));
       auto sa = std::dynamic_pointer_cast<SpoofAction>(ret);
       parseResponseConfig(vars, sa->d_responseConfig);
+      checkAllParametersConsumed("SpoofCNAMEAction", vars);
       return ret;
     });
 
@@ -2308,6 +2301,7 @@ void setupLuaActions(LuaContext& luaCtx)
       auto ret = std::shared_ptr<DNSAction>(new SpoofAction(raws));
       auto sa = std::dynamic_pointer_cast<SpoofAction>(ret);
       parseResponseConfig(vars, sa->d_responseConfig);
+      checkAllParametersConsumed("SpoofRawAction", vars);
       return ret;
     });
 
@@ -2402,6 +2396,7 @@ void setupLuaActions(LuaContext& luaCtx)
       auto ret = std::shared_ptr<DNSAction>(new RCodeAction(rcode));
       auto rca = std::dynamic_pointer_cast<RCodeAction>(ret);
       parseResponseConfig(vars, rca->d_responseConfig);
+      checkAllParametersConsumed("RCodeAction", vars);
       return ret;
     });
 
@@ -2409,6 +2404,7 @@ void setupLuaActions(LuaContext& luaCtx)
       auto ret = std::shared_ptr<DNSAction>(new ERCodeAction(rcode));
       auto erca = std::dynamic_pointer_cast<ERCodeAction>(ret);
       parseResponseConfig(vars, erca->d_responseConfig);
+      checkAllParametersConsumed("ERCodeAction", vars);
       return ret;
     });
 
@@ -2464,14 +2460,9 @@ void setupLuaActions(LuaContext& luaCtx)
 
       std::string serverID;
       std::string ipEncryptKey;
-      if (vars) {
-        if (vars->count("serverID")) {
-          serverID = boost::get<std::string>((*vars)["serverID"]);
-        }
-        if (vars->count("ipEncryptKey")) {
-          ipEncryptKey = boost::get<std::string>((*vars)["ipEncryptKey"]);
-        }
-      }
+      getOptionalValue<std::string>(vars, "serverID", serverID);
+      getOptionalValue<std::string>(vars, "ipEncryptKey", ipEncryptKey);
+      checkAllParametersConsumed("RemoteLogAction", vars);
 
       return std::shared_ptr<DNSAction>(new RemoteLogAction(logger, alterFunc, serverID, ipEncryptKey));
     });
@@ -2488,14 +2479,9 @@ void setupLuaActions(LuaContext& luaCtx)
 
       std::string serverID;
       std::string ipEncryptKey;
-      if (vars) {
-        if (vars->count("serverID")) {
-          serverID = boost::get<std::string>((*vars)["serverID"]);
-        }
-        if (vars->count("ipEncryptKey")) {
-          ipEncryptKey = boost::get<std::string>((*vars)["ipEncryptKey"]);
-        }
-      }
+      getOptionalValue<std::string>(vars, "serverID", serverID);
+      getOptionalValue<std::string>(vars, "ipEncryptKey", ipEncryptKey);
+      checkAllParametersConsumed("RemoteLogResponseAction", vars);
 
       return std::shared_ptr<DNSResponseAction>(new RemoteLogResponseAction(logger, alterFunc, serverID, ipEncryptKey, includeCNAME ? *includeCNAME : false));
     });
@@ -2564,6 +2550,7 @@ void setupLuaActions(LuaContext& luaCtx)
       auto ret = std::shared_ptr<DNSAction>(new HTTPStatusAction(status, PacketBuffer(body.begin(), body.end()), contentType ? *contentType : ""));
       auto hsa = std::dynamic_pointer_cast<HTTPStatusAction>(ret);
       parseResponseConfig(vars, hsa->d_responseConfig);
+      checkAllParametersConsumed("HTTPStatusAction", vars);
       return ret;
     });
 #endif /* HAVE_DNS_OVER_HTTPS */
@@ -2580,17 +2567,13 @@ void setupLuaActions(LuaContext& luaCtx)
 
   luaCtx.writeFunction("NegativeAndSOAAction", [](bool nxd, const std::string& zone, uint32_t ttl, const std::string& mname, const std::string& rname, uint32_t serial, uint32_t refresh, uint32_t retry, uint32_t expire, uint32_t minimum, boost::optional<responseParams_t> vars) {
       bool soaInAuthoritySection = false;
-      if (vars) {
-        if (vars->count("soaInAuthoritySection")) {
-          soaInAuthoritySection = boost::get<bool>((*vars)["soaInAuthoritySection"]);
-        }
-      }
-
+      getOptionalValue<bool>(vars, "soaInAuthoritySection", soaInAuthoritySection);
       auto ret = std::shared_ptr<DNSAction>(new NegativeAndSOAAction(nxd, DNSName(zone), ttl, DNSName(mname), DNSName(rname), serial, refresh, retry, expire, minimum, soaInAuthoritySection));
       auto action = std::dynamic_pointer_cast<NegativeAndSOAAction>(ret);
       parseResponseConfig(vars, action->d_responseConfig);
+      checkAllParametersConsumed("NegativeAndSOAAction", vars);
       return ret;
-    });
+  });
 
   luaCtx.writeFunction("SetProxyProtocolValuesAction", [](const std::vector<std::pair<uint8_t, std::string>>& values) {
       return std::shared_ptr<DNSAction>(new SetProxyProtocolValuesAction(values));
index 0506492e00339170f94ac9a64673f5bcac44a6da..67740f9e030577fa9c549d3af121d9a3b3a4d4a3 100644 (file)
@@ -73,14 +73,8 @@ void parseRuleParams(boost::optional<luaruleparams_t> params, boost::uuids::uuid
 
   string uuidStr;
 
-  if (params) {
-    if (params->count("uuid")) {
-      uuidStr = params->at("uuid");
-    }
-    if (params->count("name")) {
-      name = params->at("name");
-    }
-  }
+  getOptionalValue<std::string>(params, "uuid", uuidStr);
+  getOptionalValue<std::string>(params, "name", name);
 
   uuid = makeRuleID(uuidStr);
   creationOrder = s_creationOrder++;
@@ -96,14 +90,9 @@ static std::string rulesToString(const std::vector<T>& rules, boost::optional<ru
   size_t truncateRuleWidth = string::npos;
   std::string result;
 
-  if (vars) {
-    if (vars->count("showUUIDs")) {
-      showUUIDs = boost::get<bool>((*vars)["showUUIDs"]);
-    }
-    if (vars->count("truncateRuleWidth")) {
-      truncateRuleWidth = boost::get<int>((*vars)["truncateRuleWidth"]);
-    }
-  }
+  getOptionalValue<bool>(vars, "showUUIDs", showUUIDs);
+  getOptionalValue<int>(vars, "truncateRuleWidth", truncateRuleWidth);
+  checkAllParametersConsumed("rulesToString", vars);
 
   if (showUUIDs) {
     boost::format fmt("%-3d %-30s %-38s %9d %9d %-56s %s\n");
index 2df956334d1850f9bd95bdec694ed3b1e423d358..ba20ce302a219cb3c44e57cbe53bd7becee43a2a 100644 (file)
@@ -107,29 +107,19 @@ void resetLuaSideEffect()
 
 using localbind_t = LuaAssociativeTable<boost::variant<bool, int, std::string, LuaArray<int>, LuaArray<std::string>, LuaAssociativeTable<std::string>>>;
 
-static void parseLocalBindVars(boost::optional<localbind_t> vars, bool& reusePort, int& tcpFastOpenQueueSize, std::string& interface, std::set<int>& cpus, int& tcpListenQueueSize, uint64_t& maxInFlightQueriesPerConnection, uint64_t& tcpMaxConcurrentConnections)
+static void parseLocalBindVars(boost::optional<localbind_t>& vars, bool& reusePort, int& tcpFastOpenQueueSize, std::string& interface, std::set<int>& cpus, int& tcpListenQueueSize, uint64_t& maxInFlightQueriesPerConnection, uint64_t& tcpMaxConcurrentConnections)
 {
   if (vars) {
-    if (vars->count("reusePort")) {
-      reusePort = boost::get<bool>((*vars)["reusePort"]);
-    }
-    if (vars->count("tcpFastOpenQueueSize")) {
-      tcpFastOpenQueueSize = boost::get<int>((*vars)["tcpFastOpenQueueSize"]);
-    }
-    if (vars->count("tcpListenQueueSize")) {
-      tcpListenQueueSize = boost::get<int>((*vars)["tcpListenQueueSize"]);
-    }
-    if (vars->count("maxConcurrentTCPConnections")) {
-      tcpMaxConcurrentConnections = boost::get<int>((*vars)["maxConcurrentTCPConnections"]);
-    }
-    if (vars->count("maxInFlight")) {
-      maxInFlightQueriesPerConnection = boost::get<int>((*vars)["maxInFlight"]);
-    }
-    if (vars->count("interface")) {
-      interface = boost::get<std::string>((*vars)["interface"]);
-    }
-    if (vars->count("cpus")) {
-      for (const auto& cpu : boost::get<LuaArray<int>>((*vars)["cpus"])) {
+    LuaArray<int> setCpus;
+
+    getOptionalValue<bool>(vars, "reusePort", reusePort);
+    getOptionalValue<int>(vars, "tcpFastOpenQueueSize", tcpFastOpenQueueSize);
+    getOptionalValue<int>(vars, "tcpListenQueueSize", tcpListenQueueSize);
+    getOptionalValue<int>(vars, "maxConcurrentTCPConnections", tcpMaxConcurrentConnections);
+    getOptionalValue<int>(vars, "maxInFlight", maxInFlightQueriesPerConnection);
+    getOptionalValue<std::string>(vars, "interface", interface);
+    if (getOptionalValue<decltype(setCpus)>(vars, "cpus", setCpus) > 0) {
+      for (const auto& cpu : setCpus) {
         cpus.insert(cpu.second);
       }
     }
@@ -181,65 +171,45 @@ static bool loadTLSCertificateAndKeys(const std::string& context, std::vector<TL
   return true;
 }
 
-static void parseTLSConfig(TLSConfig& config, const std::string& context, boost::optional<localbind_t> vars)
+static void parseTLSConfig(TLSConfig& config, const std::string& context, boost::optional<localbind_t>& vars)
 {
-  if (vars->count("ciphers")) {
-    config.d_ciphers = boost::get<const string>((*vars)["ciphers"]);
-  }
-
-  if (vars->count("ciphersTLS13")) {
-    config.d_ciphers13 = boost::get<const string>((*vars)["ciphersTLS13"]);
-  }
+  getOptionalValue<std::string>(vars, "ciphers", config.d_ciphers);
+  getOptionalValue<std::string>(vars, "ciphersTLS13", config.d_ciphers13);
 
 #ifdef HAVE_LIBSSL
-  if (vars->count("minTLSVersion")) {
-    config.d_minTLSVersion = libssl_tls_version_from_string(boost::get<const string>((*vars)["minTLSVersion"]));
-  }
+  std::string minVersion;
+  if (getOptionalValue<std::string>(vars, "minTLSVersion", minVersion) > 0)
+    config.d_minTLSVersion = libssl_tls_version_from_string(minVersion);
+#else /* HAVE_LIBSSL */
+  if (vars->erase("minTLSVersion") > 0)
+    warnlog("minTLSVersion has no effect with chosen TLS library");
 #endif /* HAVE_LIBSSL */
 
-  if (vars->count("ticketKeyFile")) {
-    config.d_ticketKeyFile = boost::get<const string>((*vars)["ticketKeyFile"]);
-  }
-
-  if (vars->count("ticketsKeysRotationDelay")) {
-    config.d_ticketsKeyRotationDelay = boost::get<int>((*vars)["ticketsKeysRotationDelay"]);
-  }
-
-  if (vars->count("numberOfTicketsKeys")) {
-    config.d_numberOfTicketsKeys = boost::get<int>((*vars)["numberOfTicketsKeys"]);
-  }
-
-  if (vars->count("preferServerCiphers")) {
-    config.d_preferServerCiphers = boost::get<bool>((*vars)["preferServerCiphers"]);
-  }
-
-  if (vars->count("sessionTimeout")) {
-    config.d_sessionTimeout = boost::get<int>((*vars)["sessionTimeout"]);
-  }
-
-  if (vars->count("sessionTickets")) {
-    config.d_enableTickets = boost::get<bool>((*vars)["sessionTickets"]);
+  getOptionalValue<std::string>(vars, "ticketKeyFile", config.d_ticketKeyFile);
+  getOptionalValue<int>(vars, "ticketsKeysRotationDelay", config.d_ticketsKeyRotationDelay);
+  getOptionalValue<int>(vars, "numberOfTicketsKeys", config.d_numberOfTicketsKeys);
+  getOptionalValue<bool>(vars, "preferServerCiphers", config.d_preferServerCiphers);
+  getOptionalValue<int>(vars, "sessionTimeout", config.d_sessionTimeout);
+  getOptionalValue<bool>(vars, "sessionTickets", config.d_enableTickets);
+  int numberOfStoredSessions{0};
+  if (getOptionalValue<int>(vars, "numberOfStoredSessions", numberOfStoredSessions) > 0) {
+    if (numberOfStoredSessions < 0) {
+      errlog("Invalid value '%d' for %s() parameter 'numberOfStoredSessions', should be >= 0, dismissing", numberOfStoredSessions, context);
+      g_outputBuffer = "Invalid value '" +  std::to_string(numberOfStoredSessions) + "' for " + context + "() parameter 'numberOfStoredSessions', should be >= 0, dimissing";
+    }
+    config.d_maxStoredSessions = numberOfStoredSessions;
   }
 
-  if (vars->count("numberOfStoredSessions")) {
-    auto value = boost::get<int>((*vars)["numberOfStoredSessions"]);
-    if (value < 0) {
-      errlog("Invalid value '%d' for %s() parameter 'numberOfStoredSessions', should be >= 0, dismissing", value, context);
-      g_outputBuffer = "Invalid value '" + std::to_string(value) + "' for " + context + "() parameter 'numberOfStoredSessions', should be >= 0, dismissing";
-    }
-    config.d_maxStoredSessions = value;
-  }
-
-  if (vars->count("ocspResponses")) {
-    auto files = boost::get<LuaArray<std::string>>((*vars)["ocspResponses"]);
+  LuaArray<std::string> files;
+  if (getOptionalValue<decltype(files)>(vars, "ocspResponses", files) > 0) {
     for (const auto& file : files) {
       config.d_ocspFiles.push_back(file.second);
     }
   }
 
-  if (vars->count("keyLogFile")) {
+  if (vars->count("keyLogFile") > 0) {
 #ifdef HAVE_SSL_CTX_SET_KEYLOG_CALLBACK
-    config.d_keyLogFile = boost::get<const string>((*vars)["keyLogFile"]);
+    getOptionalValue<std::string>(vars, "keyLogFile", config.d_keyLogFile);
 #else
     errlog("TLS Key logging has been enabled using the 'keyLogFile' parameter to %s(), but this version of OpenSSL does not support it", context);
     g_outputBuffer = "TLS Key logging has been enabled using the 'keyLogFile' parameter to " + context + "(), but this version of OpenSSL does not support it";
@@ -349,6 +319,8 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
                            serverAddressStr = boost::get<string>(vars["address"]);
                          }
 
+                         // FIXME: Check vars for unknown keys, needs refactoring to move creation at the end */
+
                          if (vars.count("source")) {
                            /* handle source in the following forms:
                               - v4 address ("192.0.2.1")
@@ -374,7 +346,6 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
                              /* try to parse as interface name, or v4/v6@itf */
                              config.sourceItfName = source.substr(pos == std::string::npos ? 0 : pos + 1);
                              unsigned int itfIdx = if_nametoindex(config.sourceItfName.c_str());
-
                              if (itfIdx != 0) {
                                if (pos == 0 || pos == std::string::npos) {
                                  /* "eth0" or "@eth0" */
@@ -848,6 +819,8 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
 
     parseLocalBindVars(vars, reusePort, tcpFastOpenQueueSize, interface, cpus, tcpListenQueueSize, maxInFlightQueriesPerConn, tcpMaxConcurrentConnections);
 
+    checkAllParametersConsumed("setLocal", vars);
+
     try {
       ComboAddress loc(addr, 53);
       for (auto it = g_frontends.begin(); it != g_frontends.end();) {
@@ -897,6 +870,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     std::set<int> cpus;
 
     parseLocalBindVars(vars, reusePort, tcpFastOpenQueueSize, interface, cpus, tcpListenQueueSize, maxInFlightQueriesPerConn, tcpMaxConcurrentConnections);
+    checkAllParametersConsumed("addLocal", vars);
 
     try {
       ComboAddress loc(addr, 53);
@@ -992,11 +966,9 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
   luaCtx.writeFunction("showServers", [](boost::optional<showserversopts_t> vars) {
     setLuaNoSideEffect();
     bool showUUIDs = false;
-    if (vars) {
-      if (vars->count("showUUIDs")) {
-        showUUIDs = boost::get<bool>((*vars)["showUUIDs"]);
-      }
-    }
+    getOptionalValue<bool>(vars, "showUUIDs", showUUIDs);
+    checkAllParametersConsumed("showServers", vars);
+
     try {
       ostringstream ret;
       boost::format fmt;
@@ -1150,8 +1122,16 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       hashPlaintextCredentials = boost::get<bool>(vars->at("hashPlaintextCredentials"));
     }
 
-    if (vars->count("password")) {
-      std::string password = boost::get<std::string>(vars->at("password"));
+    std::string password;
+    std::string apiKey;
+    std::string acl;
+    LuaAssociativeTable<std::string> headers;
+    bool statsRequireAuthentication{true};
+    bool apiRequiresAuthentication{true};
+    bool dashboardRequiresAuthentication{true};
+    std::string maxConcurrentConnections;
+
+    if (getOptionalValue<std::string>(vars, "password", password) > 0) {
       auto holder = make_unique<CredentialsHolder>(std::move(password), hashPlaintextCredentials);
       if (!holder->wasHashed() && holder->isHashingAvailable()) {
         infolog("Passing a plain-text password via the 'password' parameter to 'setWebserverConfig()' is not advised, please consider generating a hashed one using 'hashPassword()' instead.");
@@ -1160,8 +1140,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       setWebserverPassword(std::move(holder));
     }
 
-    if (vars->count("apiKey")) {
-      std::string apiKey = boost::get<std::string>(vars->at("apiKey"));
+    if (getOptionalValue<std::string>(vars, "apiKey", apiKey) > 0) {
       auto holder = make_unique<CredentialsHolder>(std::move(apiKey), hashPlaintextCredentials);
       if (!holder->wasHashed() && holder->isHashingAvailable()) {
         infolog("Passing a plain-text API key via the 'apiKey' parameter to 'setWebserverConfig()' is not advised, please consider generating a hashed one using 'hashPassword()' instead.");
@@ -1170,32 +1149,28 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       setWebserverAPIKey(std::move(holder));
     }
 
-    if (vars->count("acl")) {
-      const std::string acl = boost::get<std::string>(vars->at("acl"));
-
+    if (getOptionalValue<std::string>(vars, "acl", acl) > 0) {
       setWebserverACL(acl);
     }
 
-    if (vars->count("customHeaders")) {
-      const auto headers = boost::get<std::unordered_map<std::string, std::string>>(vars->at("customHeaders"));
-
+    if (getOptionalValue<decltype(headers)>(vars, "customHeaders", headers) > 0) {
       setWebserverCustomHeaders(headers);
     }
 
-    if (vars->count("statsRequireAuthentication")) {
-      setWebserverStatsRequireAuthentication(boost::get<bool>(vars->at("statsRequireAuthentication")));
+    if (getOptionalValue<bool>(vars, "statsRequireAuthentication", statsRequireAuthentication) > 0) {
+      setWebserverStatsRequireAuthentication(statsRequireAuthentication);
     }
 
-    if (vars->count("apiRequiresAuthentication")) {
-      setWebserverAPIRequiresAuthentication(boost::get<bool>(vars->at("apiRequiresAuthentication")));
+    if (getOptionalValue<bool>(vars, "apiRequiresAuthentication", apiRequiresAuthentication) > 0) {
+      setWebserverAPIRequiresAuthentication(apiRequiresAuthentication);
     }
 
-    if (vars->count("dashboardRequiresAuthentication")) {
-      setWebserverDashboardRequiresAuthentication(boost::get<bool>(vars->at("dashboardRequiresAuthentication")));
+    if (getOptionalValue<bool>(vars, "dashboardRequiresAuthentication", dashboardRequiresAuthentication) > 0) {
+      setWebserverDashboardRequiresAuthentication(dashboardRequiresAuthentication);
     }
 
-    if (vars->count("maxConcurrentConnections")) {
-      setWebserverMaxConcurrentConnections(std::stoi(boost::get<std::string>(vars->at("maxConcurrentConnections"))));
+    if (getOptionalValue<std::string>(vars, "maxConcurrentConnections", maxConcurrentConnections) > 0) {
+      setWebserverMaxConcurrentConnections(std::stoi(maxConcurrentConnections));
     }
   });
 
@@ -1675,6 +1650,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     std::vector<DNSCryptContext::CertKeyPaths> certKeys;
 
     parseLocalBindVars(vars, reusePort, tcpFastOpenQueueSize, interface, cpus, tcpListenQueueSize, maxInFlightQueriesPerConn, tcpMaxConcurrentConnections);
+    checkAllParametersConsumed("addDNSCryptBind", vars);
 
     if (certFiles.type() == typeid(std::string) && keyFiles.type() == typeid(std::string)) {
       auto certFile = boost::get<std::string>(certFiles);
@@ -2524,43 +2500,25 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
 
     if (vars) {
       parseLocalBindVars(vars, reusePort, tcpFastOpenQueueSize, interface, cpus, tcpListenQueueSize, maxInFlightQueriesPerConn, tcpMaxConcurrentConnections);
-
-      if (vars->count("idleTimeout")) {
-        frontend->d_idleTimeout = boost::get<int>((*vars)["idleTimeout"]);
-      }
-
-      if (vars->count("serverTokens")) {
-        frontend->d_serverTokens = boost::get<const string>((*vars)["serverTokens"]);
-      }
-
-      if (vars->count("customResponseHeaders")) {
-        for (auto const& headerMap : boost::get<LuaAssociativeTable<std::string>>((*vars).at("customResponseHeaders"))) {
-          frontend->d_customResponseHeaders[boost::to_lower_copy(headerMap.first)] = headerMap.second;
+      getOptionalValue<int>(vars, "idleTimeout", frontend->d_idleTimeout);
+      getOptionalValue<std::string>(vars, "serverTokens", frontend->d_serverTokens);
+
+      LuaAssociativeTable<std::string> customResponseHeaders;
+      if (getOptionalValue<decltype(customResponseHeaders)>(vars, "customResponseHeaders", customResponseHeaders) > 0) {
+        for (auto const& headerMap : customResponseHeaders) {
+          std::pair<std::string,std::string> headerResponse = std::make_pair(boost::to_lower_copy(headerMap.first), headerMap.second);
+          frontend->d_customResponseHeaders.insert(headerResponse);
         }
       }
 
-      if (vars->count("sendCacheControlHeaders")) {
-        frontend->d_sendCacheControlHeaders = boost::get<bool>((*vars)["sendCacheControlHeaders"]);
-      }
-
-      if (vars->count("keepIncomingHeaders")) {
-        frontend->d_keepIncomingHeaders = boost::get<bool>((*vars)["keepIncomingHeaders"]);
-      }
-
-      if (vars->count("trustForwardedForHeader")) {
-        frontend->d_trustForwardedForHeader = boost::get<bool>((*vars)["trustForwardedForHeader"]);
-      }
+      getOptionalValue<bool>(vars, "sendCacheControlHeaders", frontend->d_sendCacheControlHeaders);
+      getOptionalValue<bool>(vars, "keepIncomingHeaders", frontend->d_keepIncomingHeaders);
+      getOptionalValue<bool>(vars, "trustForwardedForHeader", frontend->d_trustForwardedForHeader);
+      getOptionalValue<int>(vars, "internalPipeBufferSize", frontend->d_internalPipeBufferSize);
+      getOptionalValue<bool>(vars, "exactPathMatching", frontend->d_exactPathMatching);
 
-      if (vars->count("internalPipeBufferSize")) {
-        frontend->d_internalPipeBufferSize = boost::get<int>((*vars)["internalPipeBufferSize"]);
-      }
-
-      if (vars->count("exactPathMatching")) {
-        frontend->d_exactPathMatching = boost::get<bool>((*vars)["exactPathMatching"]);
-      }
-
-      if (vars->count("additionalAddresses")) {
-        auto addresses = boost::get<LuaArray<std::string>>(vars->at("additionalAddresses"));
+      LuaArray<std::string> addresses;
+      if (getOptionalValue<decltype(addresses)>(vars, "additionalAddresses", addresses) > 0) {
         for (const auto& [_, add] : addresses) {
           try {
             ComboAddress address(add);
@@ -2574,20 +2532,22 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       }
 
       parseTLSConfig(frontend->d_tlsConfig, "addDOHLocal", vars);
-      if (vars->count("ignoreTLSConfigurationErrors")) {
-        if (boost::get<bool>((*vars)["ignoreTLSConfigurationErrors"])) {
-          // we are asked to try to load the certificates so we can return a potential error
-          // and properly ignore the frontend before actually launching it
-          try {
-            std::map<int, std::string> ocspResponses = {};
-            auto ctx = libssl_init_server_context(frontend->d_tlsConfig, ocspResponses);
-          }
-          catch (const std::runtime_error& e) {
-            errlog("Ignoring DoH frontend: '%s'", e.what());
-            return;
-          }
+
+      bool ignoreTLSConfigurationErrors = false;
+      if (getOptionalValue<bool>(vars, "ignoreTLSConfigurationErrors", ignoreTLSConfigurationErrors) > 0 && ignoreTLSConfigurationErrors) {
+        // we are asked to try to load the certificates so we can return a potential error
+        // and properly ignore the frontend before actually launching it
+        try {
+          std::map<int, std::string> ocspResponses = {};
+          auto ctx = libssl_init_server_context(frontend->d_tlsConfig, ocspResponses);
+        }
+        catch (const std::runtime_error& e) {
+          errlog("Ignoring DoH frontend: '%s'", e.what());
+          return;
         }
       }
+
+      checkAllParametersConsumed("addDOHLocal", vars);
     }
     g_dohlocals.push_back(frontend);
     auto cs = std::make_unique<ClientState>(frontend->d_local, true, reusePort, tcpFastOpenQueueSize, interface, cpus);
@@ -2735,13 +2695,6 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     }
   });
 
-  luaCtx.registerFunction<std::string (std::shared_ptr<DOHFrontend>::*)() const>("getAddressAndPort", [](const std::shared_ptr<DOHFrontend>& frontend) {
-    if (frontend == nullptr) {
-      return std::string();
-    }
-    return frontend->d_local.toStringWithPort();
-  });
-
   luaCtx.writeFunction("addTLSLocal", [client](const std::string& addr, boost::variant<std::string, std::shared_ptr<TLSCertKeyPair>, LuaArray<std::string>, LuaArray<std::shared_ptr<TLSCertKeyPair>>> certFiles, LuaTypeOrArrayOf<std::string> keyFiles, boost::optional<localbind_t> vars) {
     if (client) {
       return;
@@ -2770,13 +2723,11 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     if (vars) {
       parseLocalBindVars(vars, reusePort, tcpFastOpenQueueSize, interface, cpus, tcpListenQueueSize, maxInFlightQueriesPerConn, tcpMaxConcurrentConns);
 
-      if (vars->count("provider")) {
-        frontend->d_provider = boost::get<const string>((*vars)["provider"]);
-        boost::algorithm::to_lower(frontend->d_provider);
-      }
+      getOptionalValue<std::string>(vars, "provider", frontend->d_provider);
+      boost::algorithm::to_lower(frontend->d_provider);
 
-      if (vars->count("additionalAddresses")) {
-        auto addresses = boost::get<LuaArray<std::string>>(vars->at("additionalAddresses"));
+      LuaArray<std::string> addresses;
+      if (getOptionalValue<decltype(addresses)>(vars, "additionalAddresses", addresses) > 0) {
         for (const auto& [_, add] : addresses) {
           try {
             ComboAddress address(add);
@@ -2790,20 +2741,22 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       }
 
       parseTLSConfig(frontend->d_tlsConfig, "addTLSLocal", vars);
-      if (vars->count("ignoreTLSConfigurationErrors")) {
-        if (boost::get<bool>((*vars)["ignoreTLSConfigurationErrors"])) {
-          // we are asked to try to load the certificates so we can return a potential error
-          // and properly ignore the frontend before actually launching it
-          try {
-            std::map<int, std::string> ocspResponses = {};
-            auto ctx = libssl_init_server_context(frontend->d_tlsConfig, ocspResponses);
-          }
-          catch (const std::runtime_error& e) {
-            errlog("Ignoring TLS frontend: '%s'", e.what());
-            return;
-          }
+
+      bool ignoreTLSConfigurationErrors = false;
+      if (getOptionalValue<bool>(vars, "ignoreTLSConfigurationErrors", ignoreTLSConfigurationErrors) > 0 && ignoreTLSConfigurationErrors) {
+        // we are asked to try to load the certificates so we can return a potential error
+        // and properly ignore the frontend before actually launching it
+        try {
+          std::map<int, std::string> ocspResponses = {};
+          auto ctx = libssl_init_server_context(frontend->d_tlsConfig, ocspResponses);
+        }
+        catch (const std::runtime_error& e) {
+          errlog("Ignoring TLS frontend: '%s'", e.what());
+          return;
         }
       }
+
+      checkAllParametersConsumed("addTLSLocal", vars);
     }
 
     try {
index 9570f3e80251d68d144204e9332c94d4fd5c683a..a477bd5e06c34b929d8a1d99b12b2c0f7b11d4bf 100644 (file)
@@ -21,6 +21,7 @@
  */
 #pragma once
 
+#include "dolog.hh"
 #include "dnsdist.hh"
 #include "dnsparser.hh"
 #include <random>
@@ -179,3 +180,40 @@ void setupLuaInspection(LuaContext& luaCtx);
 void setupLuaVars(LuaContext& luaCtx);
 void setupLuaWeb(LuaContext& luaCtx);
 void setupLuaLoadBalancingContext(LuaContext& luaCtx);
+
+/**
+ * getOptionalValue(vars, key, value)
+ *
+ * Attempts to extract value for key in vars.
+ * Erases the key from vars.
+ *
+ * returns: -1 if type wasn't compatible, 0 if not found or number of element(s) found
+ */
+template<class G, class T, class V>
+static inline int getOptionalValue(boost::optional<V>& vars, const std::string& key, T& value) {
+  /* nothing found, nothing to return */
+  if (!vars) {
+    return 0;
+  }
+
+  if (vars->count(key)) {
+    try {
+      value = boost::get<G>((*vars)[key]);
+    } catch (const boost::bad_get& e) {
+      /* key is there but isn't compatible */
+      return -1;
+    }
+  }
+  return vars->erase(key);
+}
+
+template<class V>
+static inline void checkAllParametersConsumed(const std::string& func, const boost::optional<V>& vars) {
+  /* no vars */
+  if (!vars) {
+    return;
+  }
+  for (const auto& [key, value] : *vars) {
+    warnlog("%s: Unknown key '%s' given - ignored", func, key);
+  }
+}
index 184ff33c2874bc02ebf4a84fe77f82c682d9c9be..fd62eb5318ad784aefdd89a29b96c1f1fba83e2b 100644 (file)
@@ -24,7 +24,6 @@
 #include <sys/types.h>
 
 #include "config.h"
-#include "dolog.hh"
 #include "dnsdist.hh"
 #include "dnsdist-lua.hh"
 
@@ -35,96 +34,62 @@ void setupLuaBindingsPacketCache(LuaContext& luaCtx, bool client)
   /* PacketCache */
   luaCtx.writeFunction("newPacketCache", [client](size_t maxEntries, boost::optional<LuaAssociativeTable<boost::variant<bool, size_t, LuaArray<uint16_t>>>> vars) {
 
-      bool keepStaleData = false;
-      size_t maxTTL = 86400;
-      size_t minTTL = 0;
-      size_t tempFailTTL = 60;
-      size_t maxNegativeTTL = 3600;
-      size_t staleTTL = 60;
-      size_t numberOfShards = 20;
-      bool dontAge = false;
-      bool deferrableInsertLock = true;
-      bool ecsParsing = false;
-      std::unordered_set<uint16_t> optionsToSkip{EDNSOptionCode::COOKIE};
-
-      if (vars) {
-
-        if (vars->count("deferrableInsertLock")) {
-          deferrableInsertLock = boost::get<bool>((*vars)["deferrableInsertLock"]);
-        }
-
-        if (vars->count("dontAge")) {
-          dontAge = boost::get<bool>((*vars)["dontAge"]);
-        }
-
-        if (vars->count("keepStaleData")) {
-          keepStaleData = boost::get<bool>((*vars)["keepStaleData"]);
-        }
-
-        if (vars->count("maxNegativeTTL")) {
-          maxNegativeTTL = boost::get<size_t>((*vars)["maxNegativeTTL"]);
-        }
-
-        if (vars->count("maxTTL")) {
-          maxTTL = boost::get<size_t>((*vars)["maxTTL"]);
-        }
-
-        if (vars->count("minTTL")) {
-          minTTL = boost::get<size_t>((*vars)["minTTL"]);
-        }
-
-        if (vars->count("numberOfShards")) {
-          numberOfShards = boost::get<size_t>((*vars)["numberOfShards"]);
-        }
-
-        if (vars->count("parseECS")) {
-          ecsParsing = boost::get<bool>((*vars)["parseECS"]);
-        }
-
-        if (vars->count("staleTTL")) {
-          staleTTL = boost::get<size_t>((*vars)["staleTTL"]);
-        }
-
-        if (vars->count("temporaryFailureTTL")) {
-          tempFailTTL = boost::get<size_t>((*vars)["temporaryFailureTTL"]);
-        }
-
-        if (vars->count("cookieHashing")) {
-          if (boost::get<bool>((*vars)["cookieHashing"])) {
-            optionsToSkip.erase(EDNSOptionCode::COOKIE);
-          }
-        }
-        if (vars->count("skipOptions")) {
-          for (const auto& option: boost::get<LuaArray<uint16_t>>(vars->at("skipOptions"))) {
-            optionsToSkip.insert(option.second);
-          }
-        }
+    bool keepStaleData = false;
+    size_t maxTTL = 86400;
+    size_t minTTL = 0;
+    size_t tempFailTTL = 60;
+    size_t maxNegativeTTL = 3600;
+    size_t staleTTL = 60;
+    size_t numberOfShards = 20;
+    bool dontAge = false;
+    bool deferrableInsertLock = true;
+    bool ecsParsing = false;
+    bool cookieHashing = false;
+    LuaArray<uint16_t> skipOptions;
+    std::unordered_set<uint16_t> optionsToSkip{EDNSOptionCode::COOKIE};
+
+    getOptionalValue<bool>(vars, "deferrableInsertLock", deferrableInsertLock);
+    getOptionalValue<bool>(vars, "dontAge", dontAge);
+    getOptionalValue<bool>(vars, "keepStaleData", keepStaleData);
+    getOptionalValue<size_t>(vars, "maxNegativeTTL", maxNegativeTTL);
+    getOptionalValue<size_t>(vars, "maxTTL", maxTTL);
+    getOptionalValue<size_t>(vars, "minTTL", minTTL);
+    getOptionalValue<size_t>(vars, "numberOfShards", numberOfShards);
+    getOptionalValue<bool>(vars, "parseECS", ecsParsing);
+    getOptionalValue<size_t>(vars, "staleTTL", staleTTL);
+    getOptionalValue<size_t>(vars, "temporaryFailureTTL", tempFailTTL);
+    getOptionalValue<bool>(vars, "cookieHashing", cookieHashing);
+
+    if (getOptionalValue<decltype(skipOptions)>(vars, "skipOptions", skipOptions) > 0) {
+      for (const auto& option : skipOptions) {
+        optionsToSkip.insert(option.second);
       }
+    }
 
-      if (maxEntries == 0) {
-        warnlog("The number of entries in the packet cache is set to 0, raising to 1");
-        g_outputBuffer += "The number of entries in the packet cache is set to 0, raising to 1";
-        maxEntries = 1;
-      }
+    if (cookieHashing) {
+      optionsToSkip.erase(EDNSOptionCode::COOKIE);
+    }
 
-      if (maxEntries < numberOfShards) {
-        warnlog("The number of entries (%d) in the packet cache is smaller than the number of shards (%d), decreasing the number of shards to %d", maxEntries, numberOfShards, maxEntries);
-        g_outputBuffer += "The number of entries (" + std::to_string(maxEntries) + " in the packet cache is smaller than the number of shards (" + std::to_string(numberOfShards) + "), decreasing the number of shards to " + std::to_string(maxEntries);
-        numberOfShards = maxEntries;
-      }
+    checkAllParametersConsumed("newPacketCache", vars);
 
-      if (client) {
-        maxEntries = 1;
-        numberOfShards = 1;
-      }
+    if (maxEntries < numberOfShards) {
+      warnlog("The number of entries (%d) in the packet cache is smaller than the number of shards (%d), decreasing the number of shards to %d", maxEntries, numberOfShards, maxEntries);
+      g_outputBuffer += "The number of entries (" + std::to_string(maxEntries) + " in the packet cache is smaller than the number of shards (" + std::to_string(numberOfShards) + "), decreasing the number of shards to " + std::to_string(maxEntries);
+      numberOfShards = maxEntries;
+    }
 
-      auto res = std::make_shared<DNSDistPacketCache>(maxEntries, maxTTL, minTTL, tempFailTTL, maxNegativeTTL, staleTTL, dontAge, numberOfShards, deferrableInsertLock, ecsParsing);
+    if (client) {
+      maxEntries = 1;
+      numberOfShards = 1;
+    }
 
-      res->setKeepStaleData(keepStaleData);
-      res->setSkippedOptions(optionsToSkip);
+    auto res = std::make_shared<DNSDistPacketCache>(maxEntries, maxTTL, minTTL, tempFailTTL, maxNegativeTTL, staleTTL, dontAge, numberOfShards, deferrableInsertLock, ecsParsing);
 
-      return res;
-    });
+    res->setKeepStaleData(keepStaleData);
+    res->setSkippedOptions(optionsToSkip);
+
+    return res;
+  });
 
 #ifndef DISABLE_PACKETCACHE_BINDINGS
   luaCtx.registerFunction<std::string(std::shared_ptr<DNSDistPacketCache>::*)()const>("toString", [](const std::shared_ptr<DNSDistPacketCache>& cache) {
index 6e470d34e8b7292f754d5ad7c8f3cf7906c2422f..e532a568280f620622b13020a5f0a7e83d5744da 100644 (file)
@@ -32,7 +32,7 @@
 #include "remote_logger.hh"
 
 #ifdef HAVE_FSTRM
-static void parseFSTRMOptions(const boost::optional<LuaAssociativeTable<unsigned int>>& params, LuaAssociativeTable<unsigned int>& options)
+static void parseFSTRMOptions(boost::optional<LuaAssociativeTable<unsigned int>>& params, LuaAssociativeTable<unsigned int>& options)
 {
   if (!params) {
     return;
@@ -41,9 +41,7 @@ static void parseFSTRMOptions(const boost::optional<LuaAssociativeTable<unsigned
   static std::vector<std::string> const potentialOptions = { "bufferHint", "flushTimeout", "inputQueueSize", "outputQueueSize", "queueNotifyThreshold", "reopenInterval" };
 
   for (const auto& potentialOption : potentialOptions) {
-    if (params->count(potentialOption)) {
-      options[potentialOption] = boost::get<unsigned int>(params->at(potentialOption));
-    }
+    getOptionalValue<unsigned int>(params, potentialOption, options[potentialOption]);
   }
 }
 #endif /* HAVE_FSTRM */
@@ -138,6 +136,7 @@ void setupLuaBindingsProtoBuf(LuaContext& luaCtx, bool client, bool configCheck)
 
       LuaAssociativeTable<unsigned int> options;
       parseFSTRMOptions(params, options);
+      checkAllParametersConsumed("newRemoteLogger", params);
       return std::shared_ptr<RemoteLoggerInterface>(new FrameStreamLogger(AF_UNIX, address, !client, options));
 #else
       throw std::runtime_error("fstrm support is required to build an AF_UNIX FrameStreamLogger");
@@ -152,6 +151,7 @@ void setupLuaBindingsProtoBuf(LuaContext& luaCtx, bool client, bool configCheck)
 
       LuaAssociativeTable<unsigned int> options;
       parseFSTRMOptions(params, options);
+      checkAllParametersConsumed("newFrameStreamTcpLogger", params);
       return std::shared_ptr<RemoteLoggerInterface>(new FrameStreamLogger(AF_INET, address, !client, options));
 #else
       throw std::runtime_error("fstrm with TCP support is required to build an AF_INET FrameStreamLogger");