]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: replace specific MMDB queries with a single generic query
authorEnsar Sarajčić <dev@ensarsarajcic.com>
Mon, 22 Dec 2025 13:33:58 +0000 (14:33 +0100)
committerEnsar Sarajčić <dev@ensarsarajcic.com>
Tue, 9 Jun 2026 07:54:51 +0000 (09:54 +0200)
Signed-off-by: Ensar Sarajčić <dev@ensarsarajcic.com>
pdns/dnsdistdist/dnsdist-carbon.cc
pdns/dnsdistdist/dnsdist-configuration.hh
pdns/dnsdistdist/dnsdist-kvs.cc
pdns/dnsdistdist/dnsdist-kvs.hh
pdns/dnsdistdist/dnsdist-lua-bindings-kvs.cc
pdns/dnsdistdist/dnsdist-lua-bindings-mmdb.cc
pdns/dnsdistdist/dnsdist-lua-types.hh [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-lua.hh
pdns/dnsdistdist/dnsdist-web.cc
pdns/dnsdistdist/mmdb.cc
pdns/dnsdistdist/mmdb.hh

index 60c6f894f3105810aac27509c666dee7146e509c..d959a5fa500ba04c03715bde547743e1aab1b37a 100644 (file)
@@ -19,6 +19,7 @@
  * along with this program; if not, write to the Free Software
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
  */
+#include "dnsdist-lua-types.hh"
 #ifdef HAVE_CONFIG_H
 #include "config.h"
 #endif
index 6ef8cc7d7c9c35a0e70579fb5bec6511ced42496..a4a2c25c81c4f86f9086a316cc9ab3293155c015 100644 (file)
@@ -32,6 +32,7 @@
 #include "credentials.hh"
 #include "dnsdist-actions.hh"
 #include "dnsdist-carbon.hh"
+#include "dnsdist-lua-types.hh"
 #include "dnsdist-query-count.hh"
 #include "dnsdist-rule-chains.hh"
 #include "dnsdist-server-pool.hh"
index bfcf155cd5f5c73caebf7cf7e3ad14d306d483c2..4b592fcb26f6f4afea0cf8c64190bc0d8c998940 100644 (file)
  */
 
 #include "dnsdist-kvs.hh"
+#include "dnsdist-lua-types.hh"
 #include "dolog.hh"
 
+#include <limits.h>
 #include <sys/stat.h>
+#include <boost/variant.hpp>
 
 std::vector<std::string> KeyValueLookupKeySourceIP::getKeys(const ComboAddress& addr)
 {
@@ -304,37 +307,87 @@ bool CDBKVStore::keyExists(const std::string& key)
 #endif /* HAVE_CDB */
 
 #ifdef HAVE_MMDB
-bool MMDBKVStore::keyExists(const std::string& key)
+
+#include "ext/json11/json11.hpp"
+
+std::shared_ptr<const Logr::Logger> MMDBKVStore::getLogger() const
 {
-  auto addr = makeComboAddressFromRaw(key.size() == sizeof(in_addr) ? 4 : 6, key);
-  return d_mmdb->exists(addr);
+  return dnsdist::logging::getTopLogger("mmdb-key-value-store")->withValues("path", Logging::Loggable(d_mmdb->file_name()));
 }
 
-bool MMDBKVStore::getValue(const std::string& key, std::string& value)
+json11::Json MMDBKVStore::parseAny(const LuaAny& any)
 {
-  auto address = makeComboAddressFromRaw(key.size() == sizeof(in_addr) ? 4 : 6, key);
-  if (d_field == "country") {
-    return d_mmdb->queryCountry(value, address);
+  if (any.type() == typeid(std::string)) {
+    return json11::Json(boost::get<std::string>(any));
   }
-  else if (d_field == "continent") {
-    return d_mmdb->queryContinent(value, address);
+  else if (any.type() == typeid(int64_t)) {
+    auto val = boost::get<int64_t>(any);
+    if (val > static_cast<int64_t>(INT_MAX) || val < static_cast<int64_t>(INT_MIN)) {
+      SLOG(warnlog("Error while retrieving a value from MMDB database: integer overflow. Returning null."),
+           getLogger()->error(Logr::Warning, "", "Error while retrieving a value from MMDB database: integer overflow. Returning null."));
+      return json11::Json();
+    }
+    else {
+      return json11::Json(static_cast<int>(val));
+    }
   }
-  else if (d_field == "asn") {
-    return d_mmdb->queryAS(value, address);
+  else if (any.type() == typeid(uint64_t)) {
+    auto val = boost::get<uint64_t>(any);
+    if (val > static_cast<uint64_t>(INT_MAX)) {
+      SLOG(warnlog("Error while retrieving a value from MMDB database: integer overflow. Returning null."),
+           getLogger()->error(Logr::Warning, "", "Error while retrieving a value from MMDB database: integer overflow. Returning null."));
+      return json11::Json();
+    }
+    else {
+      return json11::Json(static_cast<int>(val));
+    }
   }
-  else if (d_field == "asnum") {
-    return d_mmdb->queryASN(value, address);
+  else if (any.type() == typeid(double)) {
+    return json11::Json(boost::get<double>(any));
   }
-  else if (d_field == "city") {
-    return d_mmdb->queryCity(value, address, "en");
+  else if (any.type() == typeid(bool)) {
+    return json11::Json(boost::get<bool>(any));
+  }
+  else if (any.type() == typeid(LuaArray<LuaAny>)) {
+    auto luaArray = boost::get<LuaArray<LuaAny>>(any);
+    std::vector<json11::Json> array;
+    array.reserve(luaArray.size());
+    for (auto& kv : luaArray) {
+      array.emplace_back(parseAny(kv.second));
+    }
+    return json11::Json(array);
+  }
+  else if (any.type() == typeid(LuaAssociativeTable<LuaAny>)) {
+    auto luaTable = boost::get<LuaAssociativeTable<LuaAny>>(any);
+    std::unordered_map<std::string, json11::Json> map(luaTable.size());
+    for (auto& kv : luaTable) {
+      map.emplace(kv.first, parseAny(kv.second));
+    }
+    return json11::Json(map);
   }
   else {
-    return false;
+    return json11::Json();
   }
 }
 
-bool MMDBKVStore::reload()
+bool MMDBKVStore::keyExists(const std::string& key)
 {
-  return true;
+  auto addr = makeComboAddressFromRaw(key.size() == sizeof(in_addr) ? 4 : 6, key);
+  return d_mmdb->exists(addr);
+}
+bool MMDBKVStore::getValue(const std::string& key, std::string& value)
+{
+  auto addr = makeComboAddressFromRaw(key.size() == sizeof(in_addr) ? 4 : 6, key);
+  LuaAny ret;
+  bool result = d_mmdb->query(ret, d_queryParams, addr);
+  if (ret.type() == typeid(std::string)) {
+    value = boost::get<std::string>(ret);
+    return result;
+  }
+
+  json11::Json json = parseAny(ret);
+  json.dump(value);
+
+  return result;
 }
 #endif // HAVE_MMDB
index 456e87d9bf585138bc88fcbbf69fe2824a744513..030e03b19020ebc84effb62c1b139e2aa30f80d2 100644 (file)
@@ -24,6 +24,9 @@
 #include <memory>
 #include "dnsdist.hh"
 #include "logr.hh"
+#include "dnsdist-lua-types.hh"
+#include "ext/json11/json11.hpp"
+#include "iputils.hh"
 
 class KeyValueLookupKey
 {
@@ -239,16 +242,22 @@ private:
 class MMDBKVStore : public KeyValueStore
 {
 public:
-  MMDBKVStore(const std::string& fname, const std::string& modeStr, const std::string& field) :
-    d_mmdb(std::make_unique<MMDB>(fname, modeStr)), d_field(field) {}
+  MMDBKVStore(const std::shared_ptr<MMDB> mmdb, const LuaTypeOrArrayOf<std::string>& queryParams) :
+    d_mmdb(mmdb), d_originalParams(queryParams), d_queryParams(MMDB::convertParams(d_originalParams)) {};
 
   bool keyExists(const std::string& key) override;
   bool getValue(const std::string& key, std::string& value) override;
-  bool reload() override;
+  bool reload() override
+  {
+    return true;
+  }
 
 private:
-  std::unique_ptr<MMDB> d_mmdb{nullptr};
-  std::string d_field;
-};
+  std::shared_ptr<const Logr::Logger> getLogger() const;
+  std::shared_ptr<MMDB> d_mmdb;
+  const LuaTypeOrArrayOf<std::string> d_originalParams;
+  const boost::variant<const char*, std::vector<const char*>> d_queryParams;
 
+  json11::Json parseAny(const LuaAny& any);
+};
 #endif // HAVE_MMDB
index 63267bdfad196247224515da2b003a043bcb8ad4..9699b18a2ad1d9854193b14b389ef74e504313b8 100644 (file)
@@ -44,13 +44,13 @@ void setupLuaBindingsKVS([[maybe_unused]] LuaContext& luaCtx, [[maybe_unused]] b
 #endif /* HAVE_CDB */
 
 #ifdef HAVE_MMDB
-  luaCtx.writeFunction("newMMDBKVStore", [client](const std::string& fname, const std::string& field, std::optional<bool> mmap) {
+  luaCtx.writeFunction("newMMDBKVStore", [client](const std::shared_ptr<MMDB>& mmdb, const LuaTypeOrArrayOf<std::string>& queryParams) {
     if (client) {
       return std::shared_ptr<KeyValueStore>(nullptr);
     }
-    return std::shared_ptr<KeyValueStore>(new MMDBKVStore(fname, mmap ? "mmap" : "", field));
+    return std::shared_ptr<KeyValueStore>(new MMDBKVStore(mmdb, queryParams));
   });
-#endif /* HAVE_MMDB */
+#endif // HAVE_MMDB
 
 #if defined(HAVE_LMDB) || defined(HAVE_CDB) || defined(HAVE_MMDB)
   /* Key Value Store objects */
index 8761a4584eb7c2c73b83d4c943480546e0a666b3..79f8f93992875615b78edd697d8ea493cc86e4fa 100644 (file)
@@ -19,6 +19,7 @@
  * along with this program; if not, write to the Free Software
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
  */
+#include "dnsdist-lua-types.hh"
 #include "dnsdist-lua.hh"
 #include "iputils.hh"
 #include <memory>
@@ -33,109 +34,33 @@ void setupLuaBindingsMMDB([[maybe_unused]] LuaContext& luaCtx)
     bool mmap{false};
     getOptionalValue<bool>(vars, "mmap", mmap);
 
-    auto mmdb = std::shared_ptr<MMDB>(new MMDB(name, mmap ? "mmap" : ""));
+    auto mmdb = std::make_shared<MMDB>(name, mmap ? "mmap" : "");
 
     return mmdb;
   });
 
-  luaCtx.registerFunction<std::optional<std::string> (std::shared_ptr<MMDB>::*)(const ComboAddress&)>("queryCountry", [](std::shared_ptr<MMDB>& mmdb, const ComboAddress& ip) {
-    std::optional<std::string> result{std::nullopt};
+  luaCtx.registerFunction<std::optional<LuaAny> (std::shared_ptr<MMDB>::*)(const LuaTypeOrArrayOf<std::string>&, const ComboAddress&)>("query", [](std::shared_ptr<MMDB>& mmdb, const LuaTypeOrArrayOf<std::string>& queryParams, const ComboAddress& ip) {
+    std::optional<LuaAny> result{std::nullopt};
     if (!mmdb) {
       return result;
     }
 
-    std::string value;
-    if (mmdb->queryCountry(value, ip)) {
-      result = value;
-    }
-
-    return result;
-  });
-
-  luaCtx.registerFunction<std::optional<std::string> (std::shared_ptr<MMDB>::*)(const ComboAddress&)>("queryContinent", [](std::shared_ptr<MMDB>& mmdb, const ComboAddress& ip) {
-    std::optional<std::string> result{std::nullopt};
-    if (!mmdb) {
-      return result;
-    }
-
-    std::string value;
-    if (mmdb->queryContinent(value, ip)) {
-      result = value;
-    }
-
-    return result;
-  });
-
-  luaCtx.registerFunction<std::optional<std::string> (std::shared_ptr<MMDB>::*)(const ComboAddress&)>("queryAS", [](std::shared_ptr<MMDB>& mmdb, const ComboAddress& ip) {
-    std::optional<std::string> result{std::nullopt};
-    if (!mmdb) {
-      return result;
-    }
-
-    std::string value;
-    if (mmdb->queryAS(value, ip)) {
-      result = value;
-    }
-
-    return result;
-  });
-
-  luaCtx.registerFunction<std::optional<std::string> (std::shared_ptr<MMDB>::*)(const ComboAddress&)>("queryASN", [](std::shared_ptr<MMDB>& mmdb, const ComboAddress& ip) {
-    std::optional<std::string> result{std::nullopt};
-    if (!mmdb) {
-      return result;
-    }
-
-    std::string value;
-    if (mmdb->queryASN(value, ip)) {
-      result = value;
-    }
-
-    return result;
-  });
-
-  luaCtx.registerFunction<std::optional<std::string> (std::shared_ptr<MMDB>::*)(const ComboAddress&)>("queryRegion", [](std::shared_ptr<MMDB>& mmdb, const ComboAddress& ip) {
-    std::optional<std::string> result{std::nullopt};
-    if (!mmdb) {
-      return result;
-    }
+    LuaAny value{false};
 
-    std::string value;
-    if (mmdb->queryRegion(value, ip)) {
+    if (mmdb->query(value, MMDB::convertParams(queryParams), ip)) {
       result = value;
     }
 
     return result;
   });
 
-  luaCtx.registerFunction<std::optional<std::string> (std::shared_ptr<MMDB>::*)(const ComboAddress&, const std::string&)>("queryCity", [](std::shared_ptr<MMDB>& mmdb, const ComboAddress& ip, const std::string& language) {
-    std::optional<std::string> result{std::nullopt};
+  luaCtx.registerFunction<bool (std::shared_ptr<MMDB>::*)(const ComboAddress&)>("exists", [](std::shared_ptr<MMDB>& mmdb, const ComboAddress& ip) {
+    bool result = false;
     if (!mmdb) {
       return result;
     }
 
-    std::string value;
-    if (mmdb->queryCity(value, ip, language)) {
-      result = value;
-    }
-
-    return result;
-  });
-
-  luaCtx.registerFunction<std::optional<std::tuple<double, double, int>> (std::shared_ptr<MMDB>::*)(const ComboAddress&)>("queryLocation", [](std::shared_ptr<MMDB>& mmdb, const ComboAddress& ip) {
-    std::optional<std::tuple<double, double, int>> result{std::nullopt};
-    if (!mmdb) {
-      return result;
-    }
-
-    double lat;
-    double lon;
-    int prec;
-    if (mmdb->queryLocation(lat, lon, prec, ip)) {
-      result = {lat, lon, prec};
-    }
-
-    return result;
+    return mmdb->exists(ip);
   });
 #endif
 }
diff --git a/pdns/dnsdistdist/dnsdist-lua-types.hh b/pdns/dnsdistdist/dnsdist-lua-types.hh
new file mode 100644 (file)
index 0000000..a76a91a
--- /dev/null
@@ -0,0 +1,39 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+#include <boost/any.hpp>
+#include <boost/variant/recursive_variant.hpp>
+#include <boost/variant/recursive_wrapper.hpp>
+#include <boost/variant/variant.hpp>
+#include <boost/variant/variant_fwd.hpp>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+template <class T>
+using LuaArray = std::vector<std::pair<int, T>>;
+template <class T>
+using LuaAssociativeTable = std::unordered_map<std::string, T>;
+template <class T>
+using LuaTypeOrArrayOf = boost::variant<T, LuaArray<T>>;
+using LuaAny = boost::make_recursive_variant<std::string, int64_t, uint64_t, double, bool, LuaArray<boost::recursive_variant_>, LuaAssociativeTable<boost::recursive_variant_>>::type;
index b07bc8b1cfcfc95eff75e4e00c9f725569ece1e6..5f9e57d68df30504d0475d092cbc7176476663fd 100644 (file)
 
 #include "dolog.hh"
 #include "dnsdist.hh"
+#include "dnsdist-lua-types.hh"
 
 #include "ext/luawrapper/include/LuaContext.hpp"
 
 extern RecursiveLockGuarded<LuaContext> g_lua;
 extern std::string g_outputBuffer; // locking for this is ok, as locked by g_luamutex
 
-template <class T>
-using LuaArray = std::vector<std::pair<int, T>>;
-template <class T>
-using LuaAssociativeTable = std::unordered_map<std::string, T>;
-template <class T>
-using LuaTypeOrArrayOf = boost::variant<T, LuaArray<T>>;
-
 using luaruleparams_t = LuaAssociativeTable<std::string>;
 
 using luadnsrule_t = boost::variant<string, LuaArray<std::string>, std::shared_ptr<DNSRule>, DNSName, LuaArray<DNSName>>;
index 1afb9750a72a4b22348c17c8e8cd02e955a81326..1de7074e805140a57737749a34b8b1ead61f97bc 100644 (file)
@@ -27,6 +27,7 @@
 #include <thread>
 #include <variant>
 
+#include "dnsdist-lua-types.hh"
 #include "ext/json11/json11.hpp"
 #include <yahttp/yahttp.hpp>
 
index 1273ef28b9acfebb33a719438aa156e373420c2d..19c5560670158bd5b74ecaa4743b3261601607ee 100644 (file)
@@ -20,6 +20,9 @@
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
  */
 
+#include "dnsdist-lua-types.hh"
+#include <boost/variant/get.hpp>
+#include <memory>
 #include <string>
 #ifdef HAVE_CONFIG_H
 #include "config.h"
 #include "mmdb.hh"
 #include <maxminddb.h>
 
-MMDB::MMDB(const std::string& fname, const std::string& modeStr)
+MMDB::MMDB(const std::string& fname, const std::string& modeStr) :
+  d_fname(fname)
 {
   int ec;
   int flags = 0;
-  if (modeStr == "")
+  if (modeStr == "") {
     /* for the benefit of ifdef */
-    ;
+  }
 #ifdef HAVE_MMAP
-  else if (modeStr == "mmap")
+  else if (modeStr == "mmap") {
     flags |= MMDB_MODE_MMAP;
+  }
 #endif
-  else
+  else {
     throw std::runtime_error(std::string("Unsupported mode ") + modeStr + ("for mmdb"));
+  }
   memset(&d_db, 0, sizeof(d_db));
   if ((ec = MMDB_open(fname.c_str(), flags, &d_db)) < 0)
     throw std::runtime_error(std::string("Cannot open ") + fname + std::string(": ") + std::string(MMDB_strerror(ec)));
@@ -50,113 +56,201 @@ MMDB::MMDB(const std::string& fname, const std::string& modeStr)
               dnsdist::logging::getTopLogger("mmdb")->info(Logr::Info, "Opened MMDB database", "path", Logging::Loggable(fname), "type", Logging::Loggable(d_db.metadata.database_type), "version", Logging::Loggable(std::to_string(d_db.metadata.binary_format_major_version) + "." + std::to_string(d_db.metadata.binary_format_minor_version))));
 }
 
-bool MMDB::queryCountry(string& ret, const ComboAddress& ip)
+bool MMDB::query(LuaAny& ret, const boost::variant<const char*, std::vector<const char*>>& queryParams, const ComboAddress& ip) const
 {
   MMDB_entry_data_s data;
   MMDB_lookup_result_s res;
-  if (!mmdbLookup(ip, res))
+  if (!mmdbLookup(ip, res)) {
     return false;
-  if (MMDB_get_value(&res.entry, &data, "country", "iso_code", NULL) != MMDB_SUCCESS || !data.has_data)
+  }
+
+  if (auto q = boost::get<const char*>(&queryParams)) {
+    if (MMDB_get_value(&res.entry, &data, q, NULL) != MMDB_SUCCESS || !data.has_data)
+      return false;
+  }
+  else if (auto params = boost::get<std::vector<const char*>>(&queryParams)) {
+    if (MMDB_aget_value(&res.entry, &data, &params->at(0)) != MMDB_SUCCESS || !data.has_data)
+      return false;
+  }
+
+  if (mmdbDecode(&data, ret)) {
+    return true;
+  }
+
+  MMDB_entry_s data_entry{&d_db, data.offset};
+  auto elistopt = getEntryList(&data_entry);
+  if (!elistopt) {
     return false;
-  ret = string(data.utf8_string, data.data_size);
-  return true;
+  }
+  auto elist = std::move(*elistopt);
+  auto first = elist.getFirst();
+  return mmdbDecodeEntryList(&first, ret);
 }
 
-bool MMDB::queryContinent(string& ret, const ComboAddress& ip)
+const boost::variant<const char*, std::vector<const char*>> MMDB::convertParams(const LuaTypeOrArrayOf<std::string>& queryParams)
 {
-  MMDB_entry_data_s data;
-  MMDB_lookup_result_s res;
-  if (!mmdbLookup(ip, res))
-    return false;
-  if (MMDB_get_value(&res.entry, &data, "continent", "code", NULL) != MMDB_SUCCESS || !data.has_data)
-    return false;
-  ret = string(data.utf8_string, data.data_size);
-  return true;
+  if (auto param = boost::get<std::string>(&queryParams)) {
+    return param->c_str();
+  }
+  else if (auto params = boost::get<std::vector<std::pair<int, std::string>>>(&queryParams)) {
+    auto paramsArray = std::vector<const char*>(params->size() + 1);
+    for (size_t i = 0; i < params->size(); ++i) {
+      paramsArray.at(i) = params->at(i).second.c_str();
+    }
+    paramsArray.at(params->size()) = NULL;
+    return paramsArray;
+  }
+  else {
+    return "";
+  }
 }
 
-bool MMDB::queryAS(string& ret, const ComboAddress& ip)
+std::shared_ptr<const Logr::Logger> MMDB::getLogger() const
 {
-  MMDB_entry_data_s data;
-  MMDB_lookup_result_s res;
-  if (!mmdbLookup(ip, res))
-    return false;
-  if (MMDB_get_value(&res.entry, &data, "autonomous_system_organization", NULL) != MMDB_SUCCESS || !data.has_data)
-    return false;
-  ret = string(data.utf8_string, data.data_size);
-  return true;
+  return dnsdist::logging::getTopLogger("mmdb")->withValues("path", Logging::Loggable(d_fname));
 }
 
-bool MMDB::queryASN(string& ret, const ComboAddress& ip)
+bool MMDB::mmdbDecode(MMDB_entry_data_s* data, LuaAny& ret) const
 {
-  MMDB_entry_data_s data;
-  MMDB_lookup_result_s res;
-  if (!mmdbLookup(ip, res))
+  switch (data->type) {
+  case MMDB_DATA_TYPE_BOOLEAN:
+    ret = data->boolean;
+    break;
+  case MMDB_DATA_TYPE_UTF8_STRING:
+    ret = string(data->utf8_string, data->data_size);
+    break;
+  case MMDB_DATA_TYPE_DOUBLE:
+    ret = data->double_value;
+    break;
+  case MMDB_DATA_TYPE_FLOAT:
+    ret = data->float_value;
+    break;
+  case MMDB_DATA_TYPE_INT32:
+    ret = static_cast<int64_t>(data->int32);
+    break;
+  case MMDB_DATA_TYPE_UINT16:
+    ret = static_cast<uint64_t>(data->uint16);
+    break;
+  case MMDB_DATA_TYPE_UINT32:
+    ret = static_cast<uint64_t>(data->uint32);
+    break;
+  case MMDB_DATA_TYPE_UINT64:
+    ret = static_cast<uint64_t>(data->uint64);
+    break;
+  default:
     return false;
-  if (MMDB_get_value(&res.entry, &data, "autonomous_system_number", NULL) != MMDB_SUCCESS || !data.has_data)
-    return false;
-  ret = std::to_string(data.uint32);
+  }
   return true;
 }
 
-bool MMDB::queryRegion(string& ret, const ComboAddress& ip)
+bool MMDB::mmdbDecodeEntryList(MMDB_entry_data_list_s** data, LuaAny& ret) const
 {
-  MMDB_entry_data_s data;
-  MMDB_lookup_result_s res;
-  if (!mmdbLookup(ip, res))
+  switch ((*data)->entry_data.type) {
+  case MMDB_DATA_TYPE_BOOLEAN:
+  case MMDB_DATA_TYPE_UTF8_STRING:
+  case MMDB_DATA_TYPE_DOUBLE:
+  case MMDB_DATA_TYPE_FLOAT:
+  case MMDB_DATA_TYPE_INT32:
+  case MMDB_DATA_TYPE_UINT16:
+  case MMDB_DATA_TYPE_UINT32:
+  case MMDB_DATA_TYPE_UINT64:
+    return mmdbDecode(&((*data)->entry_data), ret);
+  case MMDB_DATA_TYPE_ARRAY:
+    return mmdbDecodeArray(data, ret);
+    break;
+  case MMDB_DATA_TYPE_MAP:
+    return mmdbDecodeMap(data, ret);
+    break;
+  default:
     return false;
-  if (MMDB_get_value(&res.entry, &data, "subdivisions", "0", "iso_code", NULL) != MMDB_SUCCESS || !data.has_data)
-    return false;
-  ret = string(data.utf8_string, data.data_size);
-  return true;
+  }
 }
 
-bool MMDB::queryCity(string& ret, const ComboAddress& ip, const string& language)
+bool MMDB::mmdbDecodeMap(MMDB_entry_data_list_s** data, LuaAny& ret) const
 {
-  MMDB_entry_data_s data;
-  MMDB_lookup_result_s res;
-  if (!mmdbLookup(ip, res))
-    return false;
-  if ((MMDB_get_value(&res.entry, &data, "cities", "0", NULL) != MMDB_SUCCESS || !data.has_data) && (MMDB_get_value(&res.entry, &data, "city", "names", language.c_str(), NULL) != MMDB_SUCCESS || !data.has_data))
-    return false;
-  ret = string(data.utf8_string, data.data_size);
+  LuaAssociativeTable<LuaAny> result;
+
+  MMDB_entry_data_list_s* this_data = *data;
+
+  for (auto size = this_data->entry_data.data_size; size > 0; --size) {
+    *data = (*data)->next;
+
+    if (!*data) {
+      break;
+    }
+
+    if ((*data)->entry_data.type != MMDB_DATA_TYPE_UTF8_STRING) {
+      // Invalid key, stop decoding
+      return false;
+    }
+
+    std::string key{(*data)->entry_data.utf8_string, (*data)->entry_data.data_size};
+
+    *data = (*data)->next;
+    if (!*data) {
+      break;
+    }
+
+    LuaAny value;
+    if (!mmdbDecodeEntryList(data, value)) {
+      // Failed value decoding, stop decoding
+      return false;
+    }
+
+    result.emplace(std::move(key), std::move(value));
+  }
+
+  ret = result;
   return true;
 }
 
-bool MMDB::queryLocation(double& latitude, double& longitude,
-                         int& prec,
-                         const ComboAddress& ip)
+bool MMDB::mmdbDecodeArray(MMDB_entry_data_list_s** data, LuaAny& ret) const
 {
-  MMDB_entry_data_s data;
-  MMDB_lookup_result_s res;
-  if (!mmdbLookup(ip, res))
-    return false;
-  if (MMDB_get_value(&res.entry, &data, "location", "latitude", NULL) != MMDB_SUCCESS || !data.has_data)
-    return false;
-  latitude = data.double_value;
-  if (MMDB_get_value(&res.entry, &data, "location", "longitude", NULL) != MMDB_SUCCESS || !data.has_data)
-    return false;
-  longitude = data.double_value;
-  if (MMDB_get_value(&res.entry, &data, "location", "accuracy_radius", NULL) != MMDB_SUCCESS || !data.has_data)
-    return false;
-  prec = data.uint16;
+  LuaArray<LuaAny> result;
+
+  MMDB_entry_data_list_s* this_data = *data;
+
+  for (uint32_t i = 0; i < this_data->entry_data.data_size; ++i) {
+    *data = (*data)->next;
+
+    if (!*data) {
+      break;
+    }
+
+    LuaAny value;
+    if (!mmdbDecodeEntryList(data, value)) {
+      // Failed value decoding, stop decoding
+      return false;
+    }
+
+    result.emplace_back(i + 1, std::move(value));
+  }
+
+  ret = result;
   return true;
 }
 
-bool MMDB::mmdbLookup(const ComboAddress& ip, MMDB_lookup_result_s& res)
+bool MMDB::mmdbLookup(const ComboAddress& ip, MMDB_lookup_result_s& res) const
 {
   int mmdb_ec = 0;
   res = MMDB_lookup_sockaddr(&d_db, reinterpret_cast<const struct sockaddr*>(&ip), &mmdb_ec);
 
   if (mmdb_ec != MMDB_SUCCESS) {
-    VERBOSESLOG(infolog("mmdbLookup(%s) failed: %s", ip.toString(), MMDB_strerror(mmdb_ec)), dnsdist::logging::getTopLogger("mmdb")->error(Logr::Info, MMDB_strerror(mmdb_ec), "mmdbLookup failed", "ip", Logging::Loggable(ip)));
+    VERBOSESLOG(infolog("mmdbLookup(%s) failed: %s", ip.toString(), MMDB_strerror(mmdb_ec)), getLogger()->error(Logr::Info, MMDB_strerror(mmdb_ec), "mmdbLookup failed", "ip", Logging::Loggable(ip)));
   }
   else if (res.found_entry) {
-    // gl.netmask = res.netmask;
-    // /* If it's a IPv6 database, IPv4 netmasks are reduced from 128, so we need to deduct
-    //    96 to get from [96,128] => [0,32] range */
-    // if (!v6 && gl.netmask > 32)
-    //   gl.netmask -= 96;
     return true;
   }
   return false;
 }
+
+std::optional<MMDBEntryList> MMDB::getEntryList(MMDB_entry_s* entry) const
+{
+  MMDB_entry_data_list_s* entry_data_list;
+  int status = MMDB_get_entry_data_list(entry, &entry_data_list);
+
+  if (status != MMDB_SUCCESS) {
+    return std::nullopt;
+  }
+  return {entry_data_list};
+}
index 5b8792406b6e10b279758aa65ff34bd4467ad5be..bfab18fc86f33cffffb68fd30c543d6fab7eca77 100644 (file)
  */
 #pragma once
 
+#include "dnsdist-lua-types.hh"
 #include "iputils.hh"
 #include <maxminddb.h>
+#include <memory>
 #include <string>
 
+class MMDBEntryList;
+
 class MMDB
 {
 public:
   MMDB(const std::string& fname, const std::string& modeStr);
 
-  bool queryCountry(std::string& ret, const ComboAddress& ip);
-  bool queryContinent(std::string& ret, const ComboAddress& ip);
-  bool queryAS(std::string& ret, const ComboAddress& ip);
-  bool queryASN(std::string& ret, const ComboAddress& ip);
-  bool queryRegion(std::string& ret, const ComboAddress& ip);
-  bool queryCity(std::string& ret, const ComboAddress& ip, const std::string& language);
-  bool queryLocation(double& latitude, double& longitude, int& prec, const ComboAddress& ip);
-  bool exists(const ComboAddress& ip)
+  static const boost::variant<const char*, std::vector<const char*>> convertParams(const LuaTypeOrArrayOf<std::string>& queryParams);
+  bool query(LuaAny& ret, const boost::variant<const char*, std::vector<const char*>>& queryParams, const ComboAddress& ip) const;
+  bool exists(const ComboAddress& ip) const
   {
     MMDB_lookup_result_s res;
     return mmdbLookup(ip, res);
   }
+  const std::string& file_name() const
+  {
+    return d_fname;
+  }
 
   ~MMDB() { MMDB_close(&d_db); };
 
 private:
+  std::string d_fname;
   MMDB_s d_db;
 
-  bool mmdbLookup(const ComboAddress& ip, MMDB_lookup_result_s& res);
+  std::shared_ptr<const Logr::Logger> getLogger() const;
+  // Decodes one of the basic types (no arrays and maps)
+  bool mmdbDecode(MMDB_entry_data_s* data, LuaAny& ret) const;
+  // Decodes whole entry data list (supports arrays and maps too)
+  bool mmdbDecodeEntryList(MMDB_entry_data_list_s** data, LuaAny& ret) const;
+  bool mmdbDecodeMap(MMDB_entry_data_list_s** data, LuaAny& ret) const;
+  bool mmdbDecodeArray(MMDB_entry_data_list_s** data, LuaAny& ret) const;
+  bool mmdbLookup(const ComboAddress& ip, MMDB_lookup_result_s& res) const;
+  std::optional<MMDBEntryList> getEntryList(MMDB_entry_s* entry) const;
+};
+
+class MMDBEntryList
+{
+public:
+  MMDBEntryList(MMDB_entry_data_list_s* first) :
+    d_entry_list_first(first, MMDB_free_entry_data_list) {}
+
+  MMDB_entry_data_list_s* getFirst() const
+  {
+    return d_entry_list_first.get();
+  }
+
+private:
+  std::unique_ptr<MMDB_entry_data_list_s, decltype(&MMDB_free_entry_data_list)> d_entry_list_first;
 };