]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Merge pull request #8505 from rgacogne/dnsdist-lua-ffi
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 18 Feb 2020 13:37:11 +0000 (14:37 +0100)
committerGitHub <noreply@github.com>
Tue, 18 Feb 2020 13:37:11 +0000 (14:37 +0100)
dnsdist: Implement LuaFFIRule, LuaFFIAction and LuaFFIResponseAction

33 files changed:
pdns/dnsdist-carbon.cc
pdns/dnsdist-console.cc
pdns/dnsdist-lbpolicies.hh [new file with mode: 0644]
pdns/dnsdist-lua-actions.cc
pdns/dnsdist-lua-bindings-dnsquestion.cc
pdns/dnsdist-lua-bindings.cc
pdns/dnsdist-lua-inspection.cc
pdns/dnsdist-lua-rules.cc
pdns/dnsdist-lua.cc
pdns/dnsdist-lua.hh
pdns/dnsdist-snmp.cc
pdns/dnsdist-web.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/.gitignore
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/dnsdist-backend.cc [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-lbpolicies.cc [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-lbpolicies.hh [new symlink]
pdns/dnsdistdist/dnsdist-lua-bindings-kvs.cc
pdns/dnsdistdist/dnsdist-lua-ffi-interface.h [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-lua-ffi.cc [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-lua-ffi.hh [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-rules.hh
pdns/dnsdistdist/docs/advanced/tuning.rst
pdns/dnsdistdist/docs/guides/serverselection.rst
pdns/dnsdistdist/docs/reference/config.rst
pdns/dnsdistdist/docs/reference/netmask.rst [new file with mode: 0644]
pdns/dnsdistdist/docs/rules-actions.rst
pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc [new file with mode: 0644]
regression-tests.dnsdist/test_Advanced.py
regression-tests.dnsdist/test_DOH.py
regression-tests.dnsdist/test_EDNSOptions.py

index 72473ca00d286dddfc8f36ca03d0d32ac090cb1d..26702dcb65db8e6320e31fc157dec8740f515920 100644 (file)
@@ -87,7 +87,7 @@ try
         }
         auto states = g_dstates.getLocal();
         for(const auto& state : *states) {
-          string serverName = state->name.empty() ? (state->remote.toString() + ":" + std::to_string(state->remote.getPort())) : state->getName();
+          string serverName = state->getName().empty() ? (state->remote.toString() + ":" + std::to_string(state->remote.getPort())) : state->getName();
           boost::replace_all(serverName, ".", "_");
           const string base = namespace_name + "." + hostname + "." + instance_name + ".servers." + serverName + ".";
           str<<base<<"queries" << ' ' << state->queries.load() << " " << now << "\r\n";
index 6551f3a8341839f04b92d2303957a58912627d4c..6c0dc6b7963ad2a1b6a7d15979f9df9f23d4b59d 100644 (file)
@@ -424,7 +424,11 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "LogAction", true, "[filename], [binary], [append], [buffered]", "Log a line for each query, to the specified file if any, to the console (require verbose) otherwise. When logging to a file, the `binary` optional parameter specifies whether we log in binary form (default) or in textual form, the `append` optional parameter specifies whether we open the file for appending or truncate each time (default), and the `buffered` optional parameter specifies whether writes to the file are buffered (default) or not." },
   { "LogResponseAction", true, "[filename], [append], [buffered]", "Log a line for each response, to the specified file if any, to the console (require verbose) otherwise. The `append` optional parameter specifies whether we open the file for appending or truncate each time (default), and the `buffered` optional parameter specifies whether writes to the file are buffered (default) or not." },
   { "LuaAction", true, "function", "Invoke a Lua function that accepts a DNSQuestion" },
+  { "LuaFFIAction", true, "function", "Invoke a Lua FFI function that accepts a DNSQuestion" },
+  { "LuaFFIResponseAction", true, "function", "Invoke a Lua FFI function that accepts a DNSResponse" },
+  { "LuaFFIRule", true, "function", "Invoke a Lua FFI function that filters DNS questions" },
   { "LuaResponseAction", true, "function", "Invoke a Lua function that accepts a DNSResponse" },
+  { "LuaRule", true, "function", "Invoke a Lua function that filters DNS questions" },
   { "MacAddrAction", true, "option", "Add the source MAC address to the query as EDNS0 option option. This action is currently only supported on Linux. Subsequent rules are processed after this action" },
   { "makeIPCipherKey", true, "password", "generates a 16-byte key that can be used to pseudonymize IP addresses with IP cipher" },
   { "makeKey", true, "", "generate a new server access key, emit configuration line ready for pasting" },
@@ -530,6 +534,7 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "setSecurityPollSuffix", true, "suffix", "set the security polling suffix to the specified value" },
   { "setServerPolicy", true, "policy", "set server selection policy to that policy" },
   { "setServerPolicyLua", true, "name, function", "set server selection policy to one named 'name' and provided by 'function'" },
+  { "setServerPolicyLuaFFI", true, "name, function", "set server selection policy to one named 'name' and provided by the Lua FFI 'function'" },
   { "setServFailWhenNoServer", true, "bool", "if set, return a ServFail when no servers are available, instead of the default behaviour of dropping the query" },
   { "setStaleCacheEntriesTTL", true, "n", "allows using cache entries expired for at most n seconds when there is no backend available to answer for a query" },
   { "setSyslogFacility", true, "facility", "set the syslog logging facility to 'facility'. Defaults to LOG_DAEMON" },
diff --git a/pdns/dnsdist-lbpolicies.hh b/pdns/dnsdist-lbpolicies.hh
new file mode 100644 (file)
index 0000000..587c22d
--- /dev/null
@@ -0,0 +1,78 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+struct dnsdist_ffi_servers_list_t;
+struct dnsdist_ffi_server_t;
+struct dnsdist_ffi_dnsquestion_t;
+
+struct DownstreamState;
+
+struct ServerPolicy
+{
+  template <class T> using NumberedVector = std::vector<std::pair<unsigned int, T> >;
+  using NumberedServerVector = NumberedVector<shared_ptr<DownstreamState>>;
+  typedef std::function<shared_ptr<DownstreamState>(const NumberedServerVector& servers, const DNSQuestion*)> policyfunc_t;
+  typedef std::function<unsigned int(dnsdist_ffi_servers_list_t* servers, dnsdist_ffi_dnsquestion_t* dq)> ffipolicyfunc_t;
+
+  ServerPolicy(const std::string& name_, policyfunc_t policy_, bool isLua_): name(name_), policy(policy_), isLua(isLua_)
+  {
+  }
+  ServerPolicy(const std::string& name_, ffipolicyfunc_t policy_): name(name_), ffipolicy(policy_), isLua(true), isFFI(true)
+  {
+  }
+  ServerPolicy()
+  {
+  }
+
+  string name;
+  policyfunc_t policy;
+  ffipolicyfunc_t ffipolicy;
+  bool isLua{false};
+  bool isFFI{false};
+
+  std::string toString() const {
+    return string("ServerPolicy") + (isLua ? " (Lua)" : "") + " \"" + name + "\"";
+  }
+};
+
+struct ServerPool;
+
+using pools_t=map<std::string,std::shared_ptr<ServerPool>>;
+std::shared_ptr<ServerPool> getPool(const pools_t& pools, const std::string& poolName);
+std::shared_ptr<ServerPool> createPoolIfNotExists(pools_t& pools, const string& poolName);
+void setPoolPolicy(pools_t& pools, const string& poolName, std::shared_ptr<ServerPolicy> policy);
+void addServerToPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server);
+void removeServerFromPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server);
+
+ServerPolicy::NumberedServerVector getDownstreamCandidates(const map<std::string,std::shared_ptr<ServerPool>>& pools, const std::string& poolName);
+
+std::shared_ptr<DownstreamState> firstAvailable(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+
+std::shared_ptr<DownstreamState> leastOutstanding(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::shared_ptr<DownstreamState> wrandom(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::shared_ptr<DownstreamState> whashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::shared_ptr<DownstreamState> whashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash);
+std::shared_ptr<DownstreamState> chashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::shared_ptr<DownstreamState> chashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash);
+std::shared_ptr<DownstreamState> roundrobin(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::shared_ptr<DownstreamState> getSelectedBackendFromPolicy(const ServerPolicy& policy, const ServerPolicy::NumberedServerVector& servers, DNSQuestion& dq);
index 9134ddb7134d0ad0dd91d83861cbe552e49793ea..82ae55ffe344ad774302e6a65303964e4c1c6b5d 100644 (file)
@@ -24,6 +24,7 @@
 #include "dnsdist.hh"
 #include "dnsdist-ecs.hh"
 #include "dnsdist-lua.hh"
+#include "dnsdist-lua-ffi.hh"
 #include "dnsdist-protobuf.hh"
 #include "dnsdist-kvs.hh"
 
@@ -354,51 +355,170 @@ public:
   }
 };
 
-DNSAction::Action LuaAction::operator()(DNSQuestion* dq, std::string* ruleresult) const
+class LuaAction : public DNSAction
 {
-  std::lock_guard<std::mutex> lock(g_luamutex);
-  try {
-    auto ret = d_func(dq);
-    if (ruleresult) {
-      if (boost::optional<std::string> rule = std::get<1>(ret)) {
-        *ruleresult = *rule;
+public:
+  typedef std::function<std::tuple<int, boost::optional<string> >(DNSQuestion* dq)> func_t;
+  LuaAction(const LuaAction::func_t& func) : d_func(func)
+  {}
+
+  DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
+  {
+    std::lock_guard<std::mutex> lock(g_luamutex);
+    try {
+      auto ret = d_func(dq);
+      if (ruleresult) {
+        if (boost::optional<std::string> rule = std::get<1>(ret)) {
+          *ruleresult = *rule;
+        }
+        else {
+          // default to empty string
+          ruleresult->clear();
+        }
       }
-      else {
-        // default to empty string
-        ruleresult->clear();
+      return static_cast<Action>(std::get<0>(ret));
+    } catch (const std::exception &e) {
+      warnlog("LuaAction failed inside Lua, returning ServFail: %s", e.what());
+    } catch (...) {
+      warnlog("LuaAction failed inside Lua, returning ServFail: [unknown exception]");
+    }
+    return DNSAction::Action::ServFail;
+  }
+
+  string toString() const override
+  {
+    return "Lua script";
+  }
+private:
+  func_t d_func;
+};
+
+class LuaResponseAction : public DNSResponseAction
+{
+public:
+  typedef std::function<std::tuple<int, boost::optional<string> >(DNSResponse* dr)> func_t;
+  LuaResponseAction(const LuaResponseAction::func_t& func) : d_func(func)
+  {}
+  DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override
+  {
+    std::lock_guard<std::mutex> lock(g_luamutex);
+    try {
+      auto ret = d_func(dr);
+      if(ruleresult) {
+        if (boost::optional<std::string> rule = std::get<1>(ret)) {
+          *ruleresult = *rule;
+        }
+        else {
+          // default to empty string
+          ruleresult->clear();
+        }
       }
+      return static_cast<Action>(std::get<0>(ret));
+    } catch (const std::exception &e) {
+      warnlog("LuaResponseAction failed inside Lua, returning ServFail: %s", e.what());
+    } catch (...) {
+      warnlog("LuaResponseAction failed inside Lua, returning ServFail: [unknown exception]");
     }
-    return (Action)std::get<0>(ret);
-  } catch (std::exception &e) {
-    warnlog("LuaAction failed inside lua, returning ServFail: %s", e.what());
-  } catch (...) {
-    warnlog("LuaAction failed inside lua, returning ServFail: [unknown exception]");
+    return DNSResponseAction::Action::ServFail;
   }
-  return DNSAction::Action::ServFail;
-}
 
-DNSResponseAction::Action LuaResponseAction::operator()(DNSResponse* dr, std::string* ruleresult) const
+  string toString() const override
+  {
+    return "Lua response script";
+  }
+private:
+  func_t d_func;
+};
+
+class LuaFFIAction: public DNSAction
 {
-  std::lock_guard<std::mutex> lock(g_luamutex);
-  try {
-    auto ret = d_func(dr);
-    if(ruleresult) {
-      if (boost::optional<std::string> rule = std::get<1>(ret)) {
-        *ruleresult = *rule;
+public:
+  typedef std::function<int(dnsdist_ffi_dnsquestion_t* dq)> func_t;
+
+  LuaFFIAction(const LuaFFIAction::func_t& func): d_func(func)
+  {
+  }
+
+  DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
+  {
+    dnsdist_ffi_dnsquestion_t dqffi(dq);
+    try {
+      std::lock_guard<std::mutex> lock(g_luamutex);
+
+      auto ret = d_func(&dqffi);
+      if (ruleresult) {
+        if (dqffi.result) {
+          *ruleresult = *dqffi.result;
+        }
+        else {
+          // default to empty string
+          ruleresult->clear();
+        }
       }
-      else {
-        // default to empty string
-        ruleresult->clear();
+      return static_cast<DNSAction::Action>(ret);
+    } catch (const std::exception &e) {
+      warnlog("LuaFFIAction failed inside Lua, returning ServFail: %s", e.what());
+    } catch (...) {
+      warnlog("LuaFFIAction failed inside Lua, returning ServFail: [unknown exception]");
+    }
+    return DNSAction::Action::ServFail;
+  }
+
+  string toString() const override
+  {
+    return "Lua FFI script";
+  }
+private:
+  func_t d_func;
+};
+
+
+class LuaFFIResponseAction: public DNSResponseAction
+{
+public:
+  typedef std::function<int(dnsdist_ffi_dnsquestion_t* dq)> func_t;
+
+  LuaFFIResponseAction(const LuaFFIResponseAction::func_t& func): d_func(func)
+  {
+  }
+
+  DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override
+  {
+    DNSQuestion* dq = dynamic_cast<DNSQuestion*>(dr);
+    if (dq == nullptr) {
+      return DNSResponseAction::Action::ServFail;
+    }
+
+    dnsdist_ffi_dnsquestion_t dqffi(dq);
+    try {
+      std::lock_guard<std::mutex> lock(g_luamutex);
+
+      auto ret = d_func(&dqffi);
+      if (ruleresult) {
+        if (dqffi.result) {
+          *ruleresult = *dqffi.result;
+        }
+        else {
+          // default to empty string
+          ruleresult->clear();
+        }
       }
+      return static_cast<DNSResponseAction::Action>(ret);
+    } catch (const std::exception &e) {
+      warnlog("LuaFFIResponseAction failed inside Lua, returning ServFail: %s", e.what());
+    } catch (...) {
+      warnlog("LuaFFIResponseAction failed inside Lua, returning ServFail: [unknown exception]");
     }
-    return (Action)std::get<0>(ret);
-  } catch (std::exception &e) {
-    warnlog("LuaResponseAction failed inside lua, returning ServFail: %s", e.what());
-  } catch (...) {
-    warnlog("LuaResponseAction failed inside lua, returning ServFail: [unknown exception]");
+    return DNSResponseAction::Action::ServFail;
   }
-  return DNSResponseAction::Action::ServFail;
-}
+
+  string toString() const override
+  {
+    return "Lua FFI script";
+  }
+private:
+  func_t d_func;
+};
 
 DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresult) const
 {
@@ -1423,6 +1543,11 @@ void setupLuaActions()
       return std::shared_ptr<DNSAction>(new LuaAction(func));
     });
 
+  g_lua.writeFunction("LuaFFIAction", [](LuaFFIAction::func_t func) {
+      setLuaSideEffect();
+      return std::shared_ptr<DNSAction>(new LuaFFIAction(func));
+    });
+
   g_lua.writeFunction("NoRecurseAction", []() {
       return std::shared_ptr<DNSAction>(new NoRecurseAction);
     });
@@ -1547,6 +1672,11 @@ void setupLuaActions()
       return std::shared_ptr<DNSResponseAction>(new LuaResponseAction(func));
     });
 
+  g_lua.writeFunction("LuaFFIResponseAction", [](LuaFFIResponseAction::func_t func) {
+      setLuaSideEffect();
+      return std::shared_ptr<DNSResponseAction>(new LuaFFIResponseAction(func));
+    });
+
   g_lua.writeFunction("RemoteLogAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)> > alterFunc, boost::optional<std::unordered_map<std::string, std::string>> vars) {
       if (logger) {
         // avoids potentially-evaluated-expression warning with clang.
index eff46f3ab5aa7c835ab483a0492dbabea2305262..71840e7f36481f79cc40cd4353102e3304c6d898 100644 (file)
@@ -64,25 +64,10 @@ void setupLuaBindingsDNSQuestion()
       return *dq.ednsOptions;
     });
   g_lua.registerFunction<std::string(DNSQuestion::*)(void)>("getTrailingData", [](const DNSQuestion& dq) {
-      const char* message = reinterpret_cast<const char*>(dq.dh);
-      const uint16_t messageLen = getDNSPacketLength(message, dq.len);
-      const std::string tail = std::string(message + messageLen, dq.len - messageLen);
-      return tail;
+      return dq.getTrailingData();
     });
   g_lua.registerFunction<bool(DNSQuestion::*)(std::string)>("setTrailingData", [](DNSQuestion& dq, const std::string& tail) {
-      char* message = reinterpret_cast<char*>(dq.dh);
-      const uint16_t messageLen = getDNSPacketLength(message, dq.len);
-      const uint16_t tailLen = tail.size();
-      if(tailLen > (dq.size - messageLen)) {
-        return false;
-      }
-
-      /* Update length and copy data from the Lua string. */
-      dq.len = messageLen + tailLen;
-      if(tailLen > 0) {
-        tail.copy(message + messageLen, tailLen);
-      }
-      return true;
+      return dq.setTrailingData(tail);
     });
 
   g_lua.registerFunction<std::string(DNSQuestion::*)()>("getServerNameIndication", [](const DNSQuestion& dq) {
@@ -150,25 +135,10 @@ void setupLuaBindingsDNSQuestion()
         editDNSPacketTTL((char*) dr.dh, dr.len, editFunc);
       });
   g_lua.registerFunction<std::string(DNSResponse::*)(void)>("getTrailingData", [](const DNSResponse& dq) {
-      const char* message = reinterpret_cast<const char*>(dq.dh);
-      const uint16_t messageLen = getDNSPacketLength(message, dq.len);
-      const std::string tail = std::string(message + messageLen, dq.len - messageLen);
-      return tail;
+      return dq.getTrailingData();
     });
   g_lua.registerFunction<bool(DNSResponse::*)(std::string)>("setTrailingData", [](DNSResponse& dq, const std::string& tail) {
-      char* message = reinterpret_cast<char*>(dq.dh);
-      const uint16_t messageLen = getDNSPacketLength(message, dq.len);
-      const uint16_t tailLen = tail.size();
-      if(tailLen > (dq.size - messageLen)) {
-        return false;
-      }
-
-      /* Update length and copy data from the Lua string. */
-      dq.len = messageLen + tailLen;
-      if(tailLen > 0) {
-        tail.copy(message + messageLen, tailLen);
-      }
-      return true;
+      return dq.setTrailingData(tail);
     });
 
   g_lua.registerFunction<void(DNSResponse::*)(std::string, std::string)>("setTag", [](DNSResponse& dr, const std::string& strLabel, const std::string& strValue) {
index 06cc42f9da88f566e31916730b729f93a40998cf..f86f79140d94e5cac5b86057238bb13e2c179442 100644 (file)
@@ -57,10 +57,12 @@ void setupLuaBindings(bool client)
       return string("No exception");
     });
   /* ServerPolicy */
-  g_lua.writeFunction("newServerPolicy", [](string name, policyfunc_t policy) { return ServerPolicy{name, policy, true};});
+  g_lua.writeFunction("newServerPolicy", [](string name, ServerPolicy::policyfunc_t policy) { return std::make_shared<ServerPolicy>(name, policy, true);});
   g_lua.registerMember("name", &ServerPolicy::name);
   g_lua.registerMember("policy", &ServerPolicy::policy);
+  g_lua.registerMember("ffipolicy", &ServerPolicy::ffipolicy);
   g_lua.registerMember("isLua", &ServerPolicy::isLua);
+  g_lua.registerMember("isFFI", &ServerPolicy::isFFI);
   g_lua.registerFunction("toString", &ServerPolicy::toString);
 
   g_lua.writeVariable("firstAvailable", ServerPolicy{"firstAvailable", firstAvailable, false});
@@ -117,7 +119,7 @@ void setupLuaBindings(bool client)
     [](DownstreamState& s, int newWeight) {s.setWeight(newWeight);}
   );
   g_lua.registerMember("order", &DownstreamState::order);
-  g_lua.registerMember("name", &DownstreamState::name);
+  g_lua.registerMember<const std::string(DownstreamState::*)>("name", [](const DownstreamState& backend) -> const std::string { return backend.getName(); }, [](DownstreamState& backend, const std::string& newName) { backend.setName(newName); });
   g_lua.registerFunction<std::string(DownstreamState::*)()>("getID", [](const DownstreamState& s) { return boost::uuids::to_string(s.id); });
 
   /* dnsheader */
@@ -172,6 +174,29 @@ void setupLuaBindings(bool client)
 
   /* ComboAddress */
   g_lua.writeFunction("newCA", [](const std::string& name) { return ComboAddress(name); });
+  g_lua.writeFunction("newCAFromRaw", [](const std::string& raw, boost::optional<uint16_t> port) {
+                                        if (raw.size() == 4) {
+                                          struct sockaddr_in sin4;
+                                          memset(&sin4, 0, sizeof(sin4));
+                                          sin4.sin_family = AF_INET;
+                                          memcpy(&sin4.sin_addr.s_addr, raw.c_str(), raw.size());
+                                          if (port) {
+                                            sin4.sin_port = htons(*port);
+                                          }
+                                          return ComboAddress(&sin4);
+                                        }
+                                        else if (raw.size() == 16) {
+                                          struct sockaddr_in6 sin6;
+                                          memset(&sin6, 0, sizeof(sin6));
+                                          sin6.sin6_family = AF_INET6;
+                                          memcpy(&sin6.sin6_addr.s6_addr, raw.c_str(), raw.size());
+                                          if (port) {
+                                            sin6.sin6_port = htons(*port);
+                                          }
+                                          return ComboAddress(&sin6);
+                                        }
+                                        return ComboAddress();
+                                      });
   g_lua.registerFunction<string(ComboAddress::*)()>("tostring", [](const ComboAddress& ca) { return ca.toString(); });
   g_lua.registerFunction<string(ComboAddress::*)()>("tostringWithPort", [](const ComboAddress& ca) { return ca.toStringWithPort(); });
   g_lua.registerFunction<string(ComboAddress::*)()>("toString", [](const ComboAddress& ca) { return ca.toString(); });
@@ -188,10 +213,12 @@ void setupLuaBindings(bool client)
   g_lua.registerFunction("isPartOf", &DNSName::isPartOf);
   g_lua.registerFunction<bool(DNSName::*)()>("chopOff", [](DNSName&dn ) { return dn.chopOff(); });
   g_lua.registerFunction<unsigned int(DNSName::*)()>("countLabels", [](const DNSName& name) { return name.countLabels(); });
+  g_lua.registerFunction<size_t(DNSName::*)()>("hash", [](const DNSName& name) { return name.hash(); });
   g_lua.registerFunction<size_t(DNSName::*)()>("wirelength", [](const DNSName& name) { return name.wirelength(); });
   g_lua.registerFunction<string(DNSName::*)()>("tostring", [](const DNSName&dn ) { return dn.toString(); });
   g_lua.registerFunction<string(DNSName::*)()>("toString", [](const DNSName&dn ) { return dn.toString(); });
   g_lua.writeFunction("newDNSName", [](const std::string& name) { return DNSName(name); });
+  g_lua.writeFunction("newDNSNameFromRaw", [](const std::string& name) { return DNSName(name.c_str(), name.size(), 0, false); });
   g_lua.writeFunction("newSuffixMatchNode", []() { return SuffixMatchNode(); });
   g_lua.writeFunction("newDNSNameSet", []() { return DNSNameSet(); });
 
@@ -233,6 +260,34 @@ void setupLuaBindings(bool client)
   });
   g_lua.registerFunction("check",(bool (SuffixMatchNode::*)(const DNSName&) const) &SuffixMatchNode::check);
 
+  /* Netmask */
+  g_lua.writeFunction("newNetmask", [](boost::variant<std::string,ComboAddress> s, boost::optional<uint8_t> bits) {
+    if (s.type() == typeid(ComboAddress)) {
+      auto ca = boost::get<ComboAddress>(s);
+      if (bits) {
+        return Netmask(ca, *bits);
+      }
+      return Netmask(ca);
+    }
+    else if (s.type() == typeid(std::string)) {
+      auto str = boost::get<std::string>(s);
+      return Netmask(str);
+    }
+    throw std::runtime_error("Invalid parameter passed to 'newNetmask()'");
+  });
+  g_lua.registerFunction("empty", &Netmask::empty);
+  g_lua.registerFunction("getBits", &Netmask::getBits);
+  g_lua.registerFunction<ComboAddress(Netmask::*)()>("getNetwork", [](const Netmask& nm) { return nm.getNetwork(); } ); // const reference makes this necessary
+  g_lua.registerFunction<ComboAddress(Netmask::*)()>("getMaskedNetwork", [](const Netmask& nm) { return nm.getMaskedNetwork(); } );
+  g_lua.registerFunction("isIpv4", &Netmask::isIPv4);
+  g_lua.registerFunction("isIPv4", &Netmask::isIPv4);
+  g_lua.registerFunction("isIpv6", &Netmask::isIPv6);
+  g_lua.registerFunction("isIPv6", &Netmask::isIPv6);
+  g_lua.registerFunction("match", (bool (Netmask::*)(const string&) const)&Netmask::match);
+  g_lua.registerFunction("toString", &Netmask::toString);
+  g_lua.registerEqFunction(&Netmask::operator==);
+  g_lua.registerToStringFunction(&Netmask::toString);
+
   /* NetmaskGroup */
   g_lua.writeFunction("newNMG", []() { return NetmaskGroup(); });
   g_lua.registerFunction<void(NetmaskGroup::*)(const std::string&mask)>("addMask", [](NetmaskGroup&nmg, const std::string& mask)
index 0d37ff320c583d27ce214151de2bc9a43abfe125..d01a007b11033247ed72836144c178f36141ebdc 100644 (file)
@@ -583,7 +583,7 @@ void setupLuaInspection()
       auto states = g_dstates.getLocal();
       counter = 0;
       for(const auto& s : *states) {
-        ret << (fmt % counter % s->name % s->remote.toStringWithPort() % s->tcpCurrentConnections % s->tcpDiedSendingQuery % s->tcpDiedReadingResponse % s->tcpGaveUp % s->tcpReadTimeouts % s->tcpWriteTimeouts % s->tcpAvgQueriesPerConnection % s->tcpAvgConnectionDuration) << endl;
+        ret << (fmt % counter % s->getName() % s->remote.toStringWithPort() % s->tcpCurrentConnections % s->tcpDiedSendingQuery % s->tcpDiedReadingResponse % s->tcpGaveUp % s->tcpReadTimeouts % s->tcpWriteTimeouts % s->tcpAvgQueriesPerConnection % s->tcpAvgConnectionDuration) << endl;
         ++counter;
       }
 
index 923c0fabfdc98c58b1fa074a0570c6b661b2392f..c33af0d7f166f6307de838b5f9b72f223f13d530 100644 (file)
@@ -486,4 +486,12 @@ void setupLuaRules()
   g_lua.writeFunction("KeyValueStoreLookupRule", [](std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey) {
       return std::shared_ptr<DNSRule>(new KeyValueStoreLookupRule(kvs, lookupKey));
     });
+
+  g_lua.writeFunction("LuaRule", [](LuaRule::func_t func) {
+      return std::shared_ptr<DNSRule>(new LuaRule(func));
+    });
+
+  g_lua.writeFunction("LuaFFIRule", [](LuaFFIRule::func_t func) {
+      return std::shared_ptr<DNSRule>(new LuaFFIRule(func));
+    });
 }
index 8e69b44754d3085d897c47f37ccc882bcfcc7e22..6ec08894fc5f46b8b85f467838e6433e5a539da2 100644 (file)
@@ -36,6 +36,9 @@
 #include "dnsdist-ecs.hh"
 #include "dnsdist-healthchecks.hh"
 #include "dnsdist-lua.hh"
+#ifdef LUAJIT_VERSION
+#include "dnsdist-lua-ffi.hh"
+#endif /* LUAJIT_VERSION */
 #include "dnsdist-rings.hh"
 #include "dnsdist-secpoll.hh"
 
@@ -207,7 +210,7 @@ static void parseTLSConfig(TLSConfig& config, const std::string& context, boost:
 
 #endif // defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
 
-void setupLuaConfig(bool client, bool configCheck)
+static void setupLuaConfig(bool client, bool configCheck)
 {
   typedef std::unordered_map<std::string, boost::variant<bool, std::string, vector<pair<int, std::string> >, DownstreamState::checkfunc_t > > newserver_t;
   g_lua.writeFunction("inClientStartup", [client]() {
@@ -323,6 +326,9 @@ void setupLuaConfig(bool client, bool configCheck)
 
       // create but don't connect the socket in client or check-config modes
       ret=std::make_shared<DownstreamState>(serverAddr, sourceAddr, sourceItf, sourceItfName, numberOfSockets, !(client || configCheck));
+      if (!(client || configCheck)) {
+        infolog("Added downstream server %s", serverAddr.toStringWithPort());
+      }
 
       if(vars.count("qps")) {
         int qpsVal=std::stoi(boost::get<string>(vars["qps"]));
@@ -383,7 +389,7 @@ void setupLuaConfig(bool client, bool configCheck)
       }
 
       if(vars.count("name")) {
-        ret->name=boost::get<string>(vars["name"]);
+        ret->setName(boost::get<string>(vars["name"]));
       }
 
       if (vars.count("id")) {
@@ -533,20 +539,6 @@ void setupLuaConfig(bool client, bool configCheck)
                         g_dstates.setState(states);
                       } );
 
-  g_lua.writeFunction("setServerPolicy", [](ServerPolicy policy)  {
-      setLuaSideEffect();
-      g_policy.setState(policy);
-    });
-  g_lua.writeFunction("setServerPolicyLua", [](string name, policyfunc_t policy)  {
-      setLuaSideEffect();
-      g_policy.setState(ServerPolicy{name, policy, true});
-    });
-
-  g_lua.writeFunction("showServerPolicy", []() {
-      setLuaSideEffect();
-      g_outputBuffer=g_policy.getLocal()->name+"\n";
-    });
-
   g_lua.writeFunction("truncateTC", [](bool tc) { setLuaSideEffect(); g_truncateTC=tc; });
   g_lua.writeFunction("fixupCase", [](bool fu) { setLuaSideEffect(); g_fixupCase=fu; });
 
@@ -693,11 +685,11 @@ void setupLuaConfig(bool client, bool configCheck)
             pools+=p;
           }
           if (showUUIDs) {
-            ret << (fmt % counter % s->name % s->remote.toStringWithPort() %
+            ret << (fmt % counter % s->getName() % s->remote.toStringWithPort() %
                     status %
                     s->queryLoad % s->qps.getRate() % s->order % s->weight % s->queries.load() % s->reuseds.load() % (s->dropRate) % (s->latencyUsec/1000.0) % s->outstanding.load() % pools % s->id) << endl;
           } else {
-            ret << (fmt % counter % s->name % s->remote.toStringWithPort() %
+            ret << (fmt % counter % s->getName() % s->remote.toStringWithPort() %
                     status %
                     s->queryLoad % s->qps.getRate() % s->order % s->weight % s->queries.load() % s->reuseds.load() % (s->dropRate) % (s->latencyUsec/1000.0) % s->outstanding.load() % pools) << endl;
           }
@@ -1399,8 +1391,8 @@ void setupLuaConfig(bool client, bool configCheck)
             if (!servers.empty()) {
               servers += ", ";
             }
-            if (!server.second->name.empty()) {
-              servers += server.second->name;
+            if (!server.second->getName().empty()) {
+              servers += server.second->getName();
               servers += " ";
             }
             servers += server.second->remote.toStringWithPort();
@@ -1698,6 +1690,27 @@ void setupLuaConfig(bool client, bool configCheck)
 #endif /* HAVE_NET_SNMP */
     });
 
+  g_lua.writeFunction("setServerPolicy", [](ServerPolicy policy) {
+      setLuaSideEffect();
+      g_policy.setState(policy);
+    });
+
+  g_lua.writeFunction("setServerPolicyLua", [](string name, ServerPolicy::policyfunc_t policy) {
+      setLuaSideEffect();
+      g_policy.setState(ServerPolicy{name, policy, true});
+    });
+
+  g_lua.writeFunction("setServerPolicyLuaFFI", [](string name, ServerPolicy::ffipolicyfunc_t policy) {
+      setLuaSideEffect();
+      auto pol = ServerPolicy(name, policy);
+      g_policy.setState(std::move(pol));
+    });
+
+  g_lua.writeFunction("showServerPolicy", []() {
+      setLuaSideEffect();
+      g_outputBuffer=g_policy.getLocal()->name+"\n";
+    });
+
   g_lua.writeFunction("setPoolServerPolicy", [](ServerPolicy policy, string pool) {
       setLuaSideEffect();
       auto localPools = g_pools.getCopy();
@@ -1705,7 +1718,7 @@ void setupLuaConfig(bool client, bool configCheck)
       g_pools.setState(localPools);
     });
 
-  g_lua.writeFunction("setPoolServerPolicyLua", [](string name, policyfunc_t policy, string pool) {
+  g_lua.writeFunction("setPoolServerPolicyLua", [](string name, ServerPolicy::policyfunc_t policy, string pool) {
       setLuaSideEffect();
       auto localPools = g_pools.getCopy();
       setPoolPolicy(localPools, pool, std::make_shared<ServerPolicy>(ServerPolicy{name, policy, true}));
@@ -2191,6 +2204,10 @@ vector<std::function<void(void)>> setupLua(bool client, bool configCheck, const
   setupLuaRules();
   setupLuaVars();
 
+#ifdef LUAJIT_VERSION
+  g_lua.executeCode(getLuaFFIWrappers());
+#endif
+
   std::ifstream ifs(config);
   if(!ifs)
     warnlog("Unable to read configuration from '%s'", config);
index 021e1248195a259ec67a585c48917eb7e6188a16..65bfec3fdc389032185692e53c89ae5783c6fbc9 100644 (file)
@@ -30,36 +30,6 @@ struct ResponseConfig
 };
 void setResponseHeadersFromConfig(dnsheader& dh, const ResponseConfig& config);
 
-class LuaAction : public DNSAction
-{
-public:
-  typedef std::function<std::tuple<int, boost::optional<string> >(DNSQuestion* dq)> func_t;
-  LuaAction(const LuaAction::func_t& func) : d_func(func)
-  {}
-  Action operator()(DNSQuestion* dq, string* ruleresult) const override;
-  string toString() const override
-  {
-    return "Lua script";
-  }
-private:
-  func_t d_func;
-};
-
-class LuaResponseAction : public DNSResponseAction
-{
-public:
-  typedef std::function<std::tuple<int, boost::optional<string> >(DNSResponse* dr)> func_t;
-  LuaResponseAction(const LuaResponseAction::func_t& func) : d_func(func)
-  {}
-  Action operator()(DNSResponse* dr, string* ruleresult) const override;
-  string toString() const override
-  {
-    return "Lua response script";
-  }
-private:
-  func_t d_func;
-};
-
 class SpoofAction : public DNSAction
 {
 public:
index c0957d4e6f437673a8779f19383c50f151092e73..eefa558b7baea1d4a144dcbe8babdce11505ed84 100644 (file)
@@ -52,6 +52,10 @@ static const oid specialMemoryUsageOID[] = { DNSDIST_STATS_OID, 39 };
 
 static std::unordered_map<oid, DNSDistStats::entry_t> s_statsMap;
 
+bool g_snmpEnabled{false};
+bool g_snmpTrapsEnabled{false};
+DNSDistSNMPAgent* g_snmpAgent{nullptr};
+
 /* We are never called for a GETNEXT if it's registered as a
    "instance", as it's "magically" handled for us.  */
 /* a instance handler also only hands us one request at a time, so
@@ -293,8 +297,8 @@ static int backendStatTable_handler(netsnmp_mib_handler* handler,
       case COLUMN_BACKENDNAME:
         snmp_set_var_typed_value(request->requestvb,
                                  ASN_OCTET_STR,
-                                 server->name.c_str(),
-                                 server->name.size());
+                                 server->getName().c_str(),
+                                 server->getName().size());
         break;
       case COLUMN_BACKENDLATENCY:
         DNSDistSNMPAgent::setCounter64Value(request,
@@ -388,8 +392,8 @@ bool DNSDistSNMPAgent::sendBackendStatusChangeTrap(const std::shared_ptr<Downstr
                             backendNameOID,
                             OID_LENGTH(backendNameOID),
                             ASN_OCTET_STR,
-                            dss->name.c_str(),
-                            dss->name.size());
+                            dss->getName().c_str(),
+                            dss->getName().size());
 
   snmp_varlist_add_variable(&varList,
                             backendAddressOID,
index d6a04d4ec254b5b1b7e49bcb33d0fa198041cc1f..797a5730f432018f9146a0a6abf9c93357fe92ff 100644 (file)
@@ -558,7 +558,7 @@ static void connectionThread(int sock, ComboAddress remote)
         for (const auto& state : *states) {
           string serverName;
 
-          if (state->name.empty())
+          if (state->getName().empty())
               serverName = state->remote.toStringWithPort();
           else
               serverName = state->getName();
@@ -821,7 +821,7 @@ static void connectionThread(int sock, ComboAddress remote)
 
        Json::object server{
          {"id", num++},
-         {"name", a->name},
+         {"name", a->getName()},
           {"address", a->remote.toStringWithPort()},
           {"state", status},
           {"qps", (double)a->queryLoad},
index 00cb2e3dfb89210b30e786410163f3ca9977d4e4..5c5803b1c728cf6a0dacb81d5865202b11b5b7be 100644 (file)
@@ -102,10 +102,6 @@ std::vector<std::unique_ptr<ClientState>> g_frontends;
 GlobalStateHolder<pools_t> g_pools;
 size_t g_udpVectorSize{1};
 
-bool g_snmpEnabled{false};
-bool g_snmpTrapsEnabled{false};
-DNSDistSNMPAgent* g_snmpAgent{nullptr};
-
 /* UDP: the grand design. Per socket we listen on for incoming queries there is one thread.
    Then we have a bunch of connected sockets for talking to downstream servers. 
    We send directly to those sockets.
@@ -141,7 +137,6 @@ bool g_servFailOnNoPolicy{false};
 bool g_truncateTC{false};
 bool g_fixupCase{false};
 bool g_preserveTrailingData{false};
-bool g_roundrobinFailOnNoServer{false};
 
 std::set<std::string> g_capabilitiesToRetain;
 
@@ -193,6 +188,30 @@ struct DelayedPacket
 
 DelayPipe<DelayedPacket>* g_delay = nullptr;
 
+std::string DNSQuestion::getTrailingData() const
+{
+  const char* message = reinterpret_cast<const char*>(this->dh);
+  const uint16_t messageLen = getDNSPacketLength(message, this->len);
+  return std::string(message + messageLen, this->len - messageLen);
+}
+
+bool DNSQuestion::setTrailingData(const std::string& tail)
+{
+  char* message = reinterpret_cast<char*>(this->dh);
+  const uint16_t messageLen = getDNSPacketLength(message, this->len);
+  const uint16_t tailLen = tail.size();
+  if (tailLen > (this->size - messageLen)) {
+    return false;
+  }
+
+  /* Update length and copy data from the Lua string. */
+  this->len = messageLen + tailLen;
+  if(tailLen > 0) {
+    tail.copy(message + messageLen, tailLen);
+  }
+  return true;
+}
+
 void doLatencyStats(double udiff)
 {
   if(udiff < 1000) ++g_stats.latency0_1;
@@ -683,353 +702,10 @@ catch(...)
   errlog("UDP responder thread died because of an exception: %s", "unknown");
 }
 
-bool DownstreamState::reconnect()
-{
-  std::unique_lock<std::mutex> tl(connectLock, std::try_to_lock);
-  if (!tl.owns_lock()) {
-    /* we are already reconnecting */
-    return false;
-  }
-
-  connected = false;
-  for (auto& fd : sockets) {
-    if (fd != -1) {
-      if (sockets.size() > 1) {
-        std::lock_guard<std::mutex> lock(socketsLock);
-        mplexer->removeReadFD(fd);
-      }
-      /* shutdown() is needed to wake up recv() in the responderThread */
-      shutdown(fd, SHUT_RDWR);
-      close(fd);
-      fd = -1;
-    }
-    if (!IsAnyAddress(remote)) {
-      fd = SSocket(remote.sin4.sin_family, SOCK_DGRAM, 0);
-      if (!IsAnyAddress(sourceAddr)) {
-        SSetsockopt(fd, SOL_SOCKET, SO_REUSEADDR, 1);
-        if (!sourceItfName.empty()) {
-#ifdef SO_BINDTODEVICE
-          int res = setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, sourceItfName.c_str(), sourceItfName.length());
-          if (res != 0) {
-            infolog("Error setting up the interface on backend socket '%s': %s", remote.toStringWithPort(), stringerror());
-          }
-#endif
-        }
-
-        SBind(fd, sourceAddr);
-      }
-      try {
-        SConnect(fd, remote);
-        if (sockets.size() > 1) {
-          std::lock_guard<std::mutex> lock(socketsLock);
-          mplexer->addReadFD(fd, [](int, boost::any) {});
-        }
-        connected = true;
-      }
-      catch(const std::runtime_error& error) {
-        infolog("Error connecting to new server with address %s: %s", remote.toStringWithPort(), error.what());
-        connected = false;
-        break;
-      }
-    }
-  }
-
-  /* if at least one (re-)connection failed, close all sockets */
-  if (!connected) {
-    for (auto& fd : sockets) {
-      if (fd != -1) {
-        if (sockets.size() > 1) {
-          std::lock_guard<std::mutex> lock(socketsLock);
-          mplexer->removeReadFD(fd);
-        }
-        /* shutdown() is needed to wake up recv() in the responderThread */
-        shutdown(fd, SHUT_RDWR);
-        close(fd);
-        fd = -1;
-      }
-    }
-  }
-
-  return connected;
-}
-void DownstreamState::hash()
-{
-  vinfolog("Computing hashes for id=%s and weight=%d", id, weight);
-  auto w = weight;
-  WriteLock wl(&d_lock);
-  hashes.clear();
-  while (w > 0) {
-    std::string uuid = boost::str(boost::format("%s-%d") % id % w);
-    unsigned int wshash = burtleCI((const unsigned char*)uuid.c_str(), uuid.size(), g_hashperturb);
-    hashes.insert(wshash);
-    --w;
-  }
-}
-
-void DownstreamState::setId(const boost::uuids::uuid& newId)
-{
-  id = newId;
-  // compute hashes only if already done
-  if (!hashes.empty()) {
-    hash();
-  }
-}
-
-void DownstreamState::setWeight(int newWeight)
-{
-  if (newWeight < 1) {
-    errlog("Error setting server's weight: downstream weight value must be greater than 0.");
-    return ;
-  }
-  weight = newWeight;
-  if (!hashes.empty()) {
-    hash();
-  }
-}
-
-DownstreamState::DownstreamState(const ComboAddress& remote_, const ComboAddress& sourceAddr_, unsigned int sourceItf_, const std::string& sourceItfName_, size_t numberOfSockets, bool connect=true): sourceItfName(sourceItfName_), remote(remote_), sourceAddr(sourceAddr_), sourceItf(sourceItf_)
-{
-  pthread_rwlock_init(&d_lock, nullptr);
-  id = getUniqueID();
-  threadStarted.clear();
-
-  mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent());
-
-  sockets.resize(numberOfSockets);
-  for (auto& fd : sockets) {
-    fd = -1;
-  }
-
-  if (connect && !IsAnyAddress(remote)) {
-    reconnect();
-    idStates.resize(g_maxOutstanding);
-    sw.start();
-    infolog("Added downstream server %s", remote.toStringWithPort());
-  }
-
-}
-
 std::mutex g_luamutex;
 LuaContext g_lua;
-
-GlobalStateHolder<ServerPolicy> g_policy;
-
-shared_ptr<DownstreamState> firstAvailable(const NumberedServerVector& servers, const DNSQuestion* dq)
-{
-  for(auto& d : servers) {
-    if(d.second->isUp() && d.second->qps.check())
-      return d.second;
-  }
-  return leastOutstanding(servers, dq);
-}
-
-// get server with least outstanding queries, and within those, with the lowest order, and within those: the fastest
-shared_ptr<DownstreamState> leastOutstanding(const NumberedServerVector& servers, const DNSQuestion* dq)
-{
-  if (servers.size() == 1 && servers[0].second->isUp()) {
-    return servers[0].second;
-  }
-
-  vector<pair<tuple<int,int,double>, shared_ptr<DownstreamState>>> poss;
-  /* so you might wonder, why do we go through this trouble? The data on which we sort could change during the sort,
-     which would suck royally and could even lead to crashes. So first we snapshot on what we sort, and then we sort */
-  poss.reserve(servers.size());
-  for(auto& d : servers) {
-    if(d.second->isUp()) {
-      poss.push_back({make_tuple(d.second->outstanding.load(), d.second->order, d.second->latencyUsec), d.second});
-    }
-  }
-  if(poss.empty())
-    return shared_ptr<DownstreamState>();
-  nth_element(poss.begin(), poss.begin(), poss.end(), [](const decltype(poss)::value_type& a, const decltype(poss)::value_type& b) { return a.first < b.first; });
-  return poss.begin()->second;
-}
-
-shared_ptr<DownstreamState> valrandom(unsigned int val, const NumberedServerVector& servers, const DNSQuestion* dq)
-{
-  vector<pair<int, shared_ptr<DownstreamState>>> poss;
-  int sum = 0;
-  int max = std::numeric_limits<int>::max();
-
-  for(auto& d : servers) {      // w=1, w=10 -> 1, 11
-    if(d.second->isUp()) {
-      // Don't overflow sum when adding high weights
-      if(d.second->weight > max - sum) {
-        sum = max;
-      } else {
-        sum += d.second->weight;
-      }
-
-      poss.push_back({sum, d.second});
-    }
-  }
-
-  // Catch poss & sum are empty to avoid SIGFPE
-  if(poss.empty())
-    return shared_ptr<DownstreamState>();
-
-  int r = val % sum;
-  auto p = upper_bound(poss.begin(), poss.end(),r, [](int r_, const decltype(poss)::value_type& a) { return  r_ < a.first;});
-  if(p==poss.end())
-    return shared_ptr<DownstreamState>();
-  return p->second;
-}
-
-shared_ptr<DownstreamState> wrandom(const NumberedServerVector& servers, const DNSQuestion* dq)
-{
-  return valrandom(random(), servers, dq);
-}
-
-uint32_t g_hashperturb;
-double g_consistentHashBalancingFactor = 0;
-shared_ptr<DownstreamState> whashed(const NumberedServerVector& servers, const DNSQuestion* dq)
-{
-  return valrandom(dq->qname->hash(g_hashperturb), servers, dq);
-}
-
-shared_ptr<DownstreamState> chashed(const NumberedServerVector& servers, const DNSQuestion* dq)
-{
-  unsigned int qhash = dq->qname->hash(g_hashperturb);
-  unsigned int sel = std::numeric_limits<unsigned int>::max();
-  unsigned int min = std::numeric_limits<unsigned int>::max();
-  shared_ptr<DownstreamState> ret = nullptr, first = nullptr;
-
-  double targetLoad = std::numeric_limits<double>::max();
-  if (g_consistentHashBalancingFactor > 0) {
-    /* we start with one, representing the query we are currently handling */
-    double currentLoad = 1;
-    for (const auto& pair : servers) {
-      currentLoad += pair.second->outstanding;
-    }
-    targetLoad = (currentLoad / servers.size()) * g_consistentHashBalancingFactor;
-  }
-
-  for (const auto& d: servers) {
-    if (d.second->isUp() && d.second->outstanding <= targetLoad) {
-      // make sure hashes have been computed
-      if (d.second->hashes.empty()) {
-        d.second->hash();
-      }
-      {
-        ReadLock rl(&(d.second->d_lock));
-        const auto& server = d.second;
-        // we want to keep track of the last hash
-        if (min > *(server->hashes.begin())) {
-          min = *(server->hashes.begin());
-          first = server;
-        }
-
-        auto hash_it = server->hashes.lower_bound(qhash);
-        if (hash_it != server->hashes.end()) {
-          if (*hash_it < sel) {
-            sel = *hash_it;
-            ret = server;
-          }
-        }
-      }
-    }
-  }
-  if (ret != nullptr) {
-    return ret;
-  }
-  if (first != nullptr) {
-    return first;
-  }
-  return shared_ptr<DownstreamState>();
-}
-
-shared_ptr<DownstreamState> roundrobin(const NumberedServerVector& servers, const DNSQuestion* dq)
-{
-  NumberedServerVector poss;
-
-  for(auto& d : servers) {
-    if(d.second->isUp()) {
-      poss.push_back(d);
-    }
-  }
-
-  const auto *res=&poss;
-  if(poss.empty() && !g_roundrobinFailOnNoServer)
-    res = &servers;
-
-  if(res->empty())
-    return shared_ptr<DownstreamState>();
-
-  static unsigned int counter;
-  return (*res)[(counter++) % res->size()].second;
-}
-
 ComboAddress g_serverControl{"127.0.0.1:5199"};
 
-std::shared_ptr<ServerPool> createPoolIfNotExists(pools_t& pools, const string& poolName)
-{
-  std::shared_ptr<ServerPool> pool;
-  pools_t::iterator it = pools.find(poolName);
-  if (it != pools.end()) {
-    pool = it->second;
-  }
-  else {
-    if (!poolName.empty())
-      vinfolog("Creating pool %s", poolName);
-    pool = std::make_shared<ServerPool>();
-    pools.insert(std::pair<std::string,std::shared_ptr<ServerPool> >(poolName, pool));
-  }
-  return pool;
-}
-
-void setPoolPolicy(pools_t& pools, const string& poolName, std::shared_ptr<ServerPolicy> policy)
-{
-  std::shared_ptr<ServerPool> pool = createPoolIfNotExists(pools, poolName);
-  if (!poolName.empty()) {
-    vinfolog("Setting pool %s server selection policy to %s", poolName, policy->name);
-  } else {
-    vinfolog("Setting default pool server selection policy to %s", policy->name);
-  }
-  pool->policy = policy;
-}
-
-void addServerToPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server)
-{
-  std::shared_ptr<ServerPool> pool = createPoolIfNotExists(pools, poolName);
-  if (!poolName.empty()) {
-    vinfolog("Adding server to pool %s", poolName);
-  } else {
-    vinfolog("Adding server to default pool");
-  }
-  pool->addServer(server);
-}
-
-void removeServerFromPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server)
-{
-  std::shared_ptr<ServerPool> pool = getPool(pools, poolName);
-
-  if (!poolName.empty()) {
-    vinfolog("Removing server from pool %s", poolName);
-  }
-  else {
-    vinfolog("Removing server from default pool");
-  }
-
-  pool->removeServer(server);
-}
-
-std::shared_ptr<ServerPool> getPool(const pools_t& pools, const std::string& poolName)
-{
-  pools_t::const_iterator it = pools.find(poolName);
-
-  if (it == pools.end()) {
-    throw std::out_of_range("No pool named " + poolName);
-  }
-
-  return it->second;
-}
-
-NumberedServerVector getDownstreamCandidates(const pools_t& pools, const std::string& poolName)
-{
-  std::shared_ptr<ServerPool> pool = getPool(pools, poolName);
-  return pool->getServers();
-}
 
 static void spoofResponseFromString(DNSQuestion& dq, const string& spoofContent, bool raw)
 {
@@ -1500,13 +1176,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders&
       policy = *(serverPool->policy);
     }
     auto servers = serverPool->getServers();
-    if (policy.isLua) {
-      std::lock_guard<std::mutex> lock(g_luamutex);
-      selectedBackend = policy.policy(servers, &dq);
-    }
-    else {
-      selectedBackend = policy.policy(servers, &dq);
-    }
+    selectedBackend = getSelectedBackendFromPolicy(policy, servers, dq);
 
     uint16_t cachedResponseSize = dq.size;
     uint32_t allowExpired = selectedBackend ? 0 : g_staleCacheEntriesTTL;
@@ -2006,7 +1676,7 @@ static void healthChecksThread()
           --dss->outstanding;
           ++g_stats.downstreamTimeouts; // this is an 'actively' discovered timeout
           vinfolog("Had a downstream timeout from %s (%s) for query for %s|%s from %s",
-                   dss->remote.toStringWithPort(), dss->name,
+                   dss->remote.toStringWithPort(), dss->getName(),
                    ids.qname.toLogString(), QType(ids.qtype).getName(), ids.origRemote.toStringWithPort());
 
           struct timespec ts;
index 9443dc8ce943e6cea281feb07817312d4a97ecbe..99f85ee0c02c10e763fc5a069dc62b705d632d16 100644 (file)
@@ -39,6 +39,7 @@
 #include "dnscrypt.hh"
 #include "dnsdist-cache.hh"
 #include "dnsdist-dynbpf.hh"
+#include "dnsdist-lbpolicies.hh"
 #include "dnsname.hh"
 #include "doh.hh"
 #include "ednsoptions.hh"
@@ -70,6 +71,9 @@ struct DNSQuestion
   DNSQuestion& operator=(const DNSQuestion&) = delete;
   DNSQuestion(DNSQuestion&&) = default;
 
+  std::string getTrailingData() const;
+  bool setTrailingData(const std::string&);
+
 #ifdef HAVE_PROTOBUF
   boost::optional<boost::uuids::uuid> uniqueId;
 #endif
@@ -768,7 +772,7 @@ struct DownstreamState
     pthread_rwlock_destroy(&d_lock);
   }
   boost::uuids::uuid id;
-  std::set<unsigned int> hashes;
+  std::vector<unsigned int> hashes;
   mutable pthread_rwlock_t d_lock;
   std::vector<int> sockets;
   const std::string sourceItfName;
@@ -804,7 +808,6 @@ struct DownstreamState
   std::atomic<double> tcpAvgQueriesPerConnection{0.0};
   /* in ms */
   std::atomic<double> tcpAvgConnectionDuration{0.0};
-  string name;
   size_t socketsOffset{0};
   double queryLoad{0.0};
   double dropRate{0.0};
@@ -849,17 +852,17 @@ struct DownstreamState
   void setDown() { availability = Availability::Down; }
   void setAuto() { availability = Availability::Auto; }
   string getName() const {
-    if (name.empty()) {
-      return remote.toStringWithPort();
-    }
     return name;
   }
   string getNameWithAddr() const {
-    if (name.empty()) {
-      return remote.toStringWithPort();
-    }
-    return name + " (" + remote.toStringWithPort()+ ")";
+    return nameWithAddr;
+  }
+  void setName(const std::string& newName)
+  {
+    name = newName;
+    nameWithAddr = newName.empty() ? remote.toStringWithPort() : (name + " (" + remote.toStringWithPort()+ ")");
   }
+
   string getStatus() const
   {
     string status;
@@ -881,11 +884,12 @@ struct DownstreamState
     tcpAvgQueriesPerConnection = (99.0 * tcpAvgQueriesPerConnection / 100.0) + (nbQueries / 100.0);
     tcpAvgConnectionDuration = (99.0 * tcpAvgConnectionDuration / 100.0) + (durationMs / 100.0);
   }
+private:
+  std::string name;
+  std::string nameWithAddr;
 };
 using servers_t =vector<std::shared_ptr<DownstreamState>>;
 
-template <class T> using NumberedVector = std::vector<std::pair<unsigned int, T> >;
-
 void responderThread(std::shared_ptr<DownstreamState> state);
 extern std::mutex g_luamutex;
 extern LuaContext g_lua;
@@ -902,19 +906,6 @@ public:
   mutable std::atomic<uint64_t> d_matches{0};
 };
 
-using NumberedServerVector = NumberedVector<shared_ptr<DownstreamState>>;
-typedef std::function<shared_ptr<DownstreamState>(const NumberedServerVector& servers, const DNSQuestion*)> policyfunc_t;
-
-struct ServerPolicy
-{
-  string name;
-  policyfunc_t policy;
-  bool isLua;
-  std::string toString() const {
-    return string("ServerPolicy") + (isLua ? " (Lua)" : "") + " \"" + name + "\"";
-  }
-};
-
 struct ServerPool
 {
   ServerPool()
@@ -953,9 +944,9 @@ struct ServerPool
     return count;
   }
 
-  NumberedVector<shared_ptr<DownstreamState>> getServers()
+  ServerPolicy::NumberedServerVector getServers()
   {
-    NumberedVector<shared_ptr<DownstreamState>> result;
+    ServerPolicy::NumberedServerVector result;
     {
       ReadLock rl(&d_lock);
       result = d_servers;
@@ -1002,14 +993,10 @@ struct ServerPool
   }
 
 private:
-  NumberedVector<shared_ptr<DownstreamState>> d_servers;
+  ServerPolicy::NumberedServerVector d_servers;
   pthread_rwlock_t d_lock;
   bool d_useECS{false};
 };
-using pools_t=map<std::string,std::shared_ptr<ServerPool>>;
-void setPoolPolicy(pools_t& pools, const string& poolName, std::shared_ptr<ServerPolicy> policy);
-void addServerToPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server);
-void removeServerFromPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server);
 
 struct CarbonConfig
 {
@@ -1112,17 +1099,7 @@ struct LocalHolders
 struct dnsheader;
 
 void controlThread(int fd, ComboAddress local);
-std::shared_ptr<ServerPool> getPool(const pools_t& pools, const std::string& poolName);
-std::shared_ptr<ServerPool> createPoolIfNotExists(pools_t& pools, const string& poolName);
-NumberedServerVector getDownstreamCandidates(const pools_t& pools, const std::string& poolName);
-
-std::shared_ptr<DownstreamState> firstAvailable(const NumberedServerVector& servers, const DNSQuestion* dq);
-
-std::shared_ptr<DownstreamState> leastOutstanding(const NumberedServerVector& servers, const DNSQuestion* dq);
-std::shared_ptr<DownstreamState> wrandom(const NumberedServerVector& servers, const DNSQuestion* dq);
-std::shared_ptr<DownstreamState> whashed(const NumberedServerVector& servers, const DNSQuestion* dq);
-std::shared_ptr<DownstreamState> chashed(const NumberedServerVector& servers, const DNSQuestion* dq);
-std::shared_ptr<DownstreamState> roundrobin(const NumberedServerVector& servers, const DNSQuestion* dq);
+vector<std::function<void(void)>> setupLua(bool client, const std::string& config);
 
 struct WebserverConfig
 {
index 5cec6c5e8f509f43c772c05c100016761b72ad24..30c50de897c677d5d2055f6273d0e0d986564ec9 100644 (file)
@@ -18,6 +18,7 @@
 /configure
 /depcomp
 /dnsdist.1
+/dnsdist-lua-ffi-interface.inc
 /dnslabeltext.cc
 /ext/ipcrypt/Makefile
 /ext/ipcrypt/Makefile.in
index 78ae0624587d64688327a9cb027f730f87b05fb5..2e0d9f9afed5ff606c440ec128759bf393d6d849 100644 (file)
@@ -10,6 +10,7 @@ CLEANFILES = \
        dnsmessage.pb.h \
        htmlfiles.h.tmp \
        htmlfiles.h \
+       dnsdist-lua-ffi-interface.inc \
        dnstap.pb.cc \
        dnstap.pb.h
 
@@ -17,12 +18,18 @@ dnslabeltext.cc: dnslabeltext.rl
        $(AM_V_GEN)$(RAGEL) $< -o dnslabeltext.cc
 
 BUILT_SOURCES=htmlfiles.h \
+       dnsdist-lua-ffi-interface.inc \
        dnslabeltext.cc
 
 htmlfiles.h: $(srcdir)/html/*
        $(AM_V_GEN)$(srcdir)/incfiles $(srcdir) > $@.tmp
        @mv $@.tmp $@
 
+dnsdist-lua-ffi-interface.inc: dnsdist-lua-ffi-interface.h
+       echo 'R"FFIContent(' > $@
+       cat $< >> $@
+       echo ')FFIContent"' >> $@
+
 SRC_JS_FILES := $(wildcard src_js/*.js)
 MIN_JS_FILES := $(patsubst src_js/%.js,html/js/%.min.js,$(SRC_JS_FILES))
 
@@ -117,6 +124,7 @@ dnsdist_SOURCES = \
        dns.cc dns.hh \
        dnscrypt.cc dnscrypt.hh \
        dnsdist.cc dnsdist.hh \
+       dnsdist-backend.cc \
        dnsdist-dynbpf.cc dnsdist-dynbpf.hh \
        dnsdist-cache.cc dnsdist-cache.hh \
        dnsdist-carbon.cc \
@@ -127,6 +135,7 @@ dnsdist_SOURCES = \
        dnsdist-healthchecks.cc dnsdist-healthchecks.hh \
        dnsdist-idstate.cc \
        dnsdist-kvs.hh dnsdist-kvs.cc \
+       dnsdist-lbpolicies.cc dnsdist-lbpolicies.hh \
        dnsdist-lua.hh dnsdist-lua.cc \
        dnsdist-lua-actions.cc \
        dnsdist-lua-bindings.cc \
@@ -135,6 +144,8 @@ dnsdist_SOURCES = \
        dnsdist-lua-bindings-kvs.cc \
        dnsdist-lua-bindings-packetcache.cc \
        dnsdist-lua-bindings-protobuf.cc \
+       dnsdist-lua-ffi.cc dnsdist-lua-ffi.hh \
+       dnsdist-lua-ffi-interface.h dnsdist-lua-ffi-interface.inc \
        dnsdist-lua-inspection.cc \
        dnsdist-lua-inspection-ffi.cc dnsdist-lua-inspection-ffi.hh \
        dnsdist-lua-rules.cc \
@@ -197,6 +208,7 @@ testrunner_SOURCES = \
        test-dnsdist_cc.cc \
        test-dnsdistdynblocks_hh.cc \
        test-dnsdistkvs_cc.cc \
+       test-dnsdistlbpolicies_cc.cc \
        test-dnsdistpacketcache_cc.cc \
        test-dnsdistrings_cc.cc \
        test-dnsdistrules_cc.cc \
@@ -206,10 +218,14 @@ testrunner_SOURCES = \
        cachecleaner.hh \
        circular_buffer.hh \
        dnsdist.hh \
+       dnsdist-backend.cc \
        dnsdist-cache.cc dnsdist-cache.hh \
        dnsdist-dynblocks.cc dnsdist-dynblocks.hh \
        dnsdist-ecs.cc dnsdist-ecs.hh \
        dnsdist-kvs.cc dnsdist-kvs.hh \
+       dnsdist-lbpolicies.cc dnsdist-lbpolicies.hh \
+       dnsdist-lua-ffi.cc dnsdist-lua-ffi.hh \
+       dnsdist-lua-ffi-interface.h dnsdist-lua-ffi-interface.inc \
        dnsdist-rings.hh \
        dnsdist-xpf.cc dnsdist-xpf.hh \
        dnscrypt.cc dnscrypt.hh \
@@ -234,7 +250,9 @@ testrunner_SOURCES = \
        statnode.cc statnode.hh \
        threadname.hh threadname.cc \
        testrunner.cc \
-       xpf.cc xpf.hh
+       uuid-utils.hh uuid-utils.cc \
+       xpf.cc xpf.hh \
+       ext/luawrapper/include/LuaContext.hpp
 
 dnsdist_LDFLAGS = \
        $(AM_LDFLAGS) \
@@ -262,8 +280,9 @@ testrunner_LDFLAGS = \
 
 testrunner_LDADD = \
        $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) \
-       $(LIBSODIUM_LIBS) \
        $(FSTRM_LIBS) \
+       $(LIBSODIUM_LIBS) \
+       $(LUA_LIBS) \
        $(RT_LIBS) \
        $(SANITIZER_FLAGS) \
        $(LIBCAP_LIBS)
diff --git a/pdns/dnsdistdist/dnsdist-backend.cc b/pdns/dnsdistdist/dnsdist-backend.cc
new file mode 100644 (file)
index 0000000..606977c
--- /dev/null
@@ -0,0 +1,151 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "dnsdist.hh"
+#include "dolog.hh"
+
+bool DownstreamState::reconnect()
+{
+  std::unique_lock<std::mutex> tl(connectLock, std::try_to_lock);
+  if (!tl.owns_lock()) {
+    /* we are already reconnecting */
+    return false;
+  }
+
+  connected = false;
+  for (auto& fd : sockets) {
+    if (fd != -1) {
+      if (sockets.size() > 1) {
+        std::lock_guard<std::mutex> lock(socketsLock);
+        mplexer->removeReadFD(fd);
+      }
+      /* shutdown() is needed to wake up recv() in the responderThread */
+      shutdown(fd, SHUT_RDWR);
+      close(fd);
+      fd = -1;
+    }
+    if (!IsAnyAddress(remote)) {
+      fd = SSocket(remote.sin4.sin_family, SOCK_DGRAM, 0);
+      if (!IsAnyAddress(sourceAddr)) {
+        SSetsockopt(fd, SOL_SOCKET, SO_REUSEADDR, 1);
+        if (!sourceItfName.empty()) {
+#ifdef SO_BINDTODEVICE
+          int res = setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, sourceItfName.c_str(), sourceItfName.length());
+          if (res != 0) {
+            infolog("Error setting up the interface on backend socket '%s': %s", remote.toStringWithPort(), stringerror());
+          }
+#endif
+        }
+
+        SBind(fd, sourceAddr);
+      }
+      try {
+        SConnect(fd, remote);
+        if (sockets.size() > 1) {
+          std::lock_guard<std::mutex> lock(socketsLock);
+          mplexer->addReadFD(fd, [](int, boost::any) {});
+        }
+        connected = true;
+      }
+      catch(const std::runtime_error& error) {
+        infolog("Error connecting to new server with address %s: %s", remote.toStringWithPort(), error.what());
+        connected = false;
+        break;
+      }
+    }
+  }
+
+  /* if at least one (re-)connection failed, close all sockets */
+  if (!connected) {
+    for (auto& fd : sockets) {
+      if (fd != -1) {
+        if (sockets.size() > 1) {
+          std::lock_guard<std::mutex> lock(socketsLock);
+          mplexer->removeReadFD(fd);
+        }
+        /* shutdown() is needed to wake up recv() in the responderThread */
+        shutdown(fd, SHUT_RDWR);
+        close(fd);
+        fd = -1;
+      }
+    }
+  }
+
+  return connected;
+}
+void DownstreamState::hash()
+{
+  vinfolog("Computing hashes for id=%s and weight=%d", id, weight);
+  auto w = weight;
+  WriteLock wl(&d_lock);
+  hashes.clear();
+  hashes.reserve(w);
+  while (w > 0) {
+    std::string uuid = boost::str(boost::format("%s-%d") % id % w);
+    unsigned int wshash = burtleCI(reinterpret_cast<const unsigned char*>(uuid.c_str()), uuid.size(), g_hashperturb);
+    hashes.push_back(wshash);
+    --w;
+  }
+  std::sort(hashes.begin(), hashes.end());
+}
+
+void DownstreamState::setId(const boost::uuids::uuid& newId)
+{
+  id = newId;
+  // compute hashes only if already done
+  if (!hashes.empty()) {
+    hash();
+  }
+}
+
+void DownstreamState::setWeight(int newWeight)
+{
+  if (newWeight < 1) {
+    errlog("Error setting server's weight: downstream weight value must be greater than 0.");
+    return ;
+  }
+  weight = newWeight;
+  if (!hashes.empty()) {
+    hash();
+  }
+}
+
+DownstreamState::DownstreamState(const ComboAddress& remote_, const ComboAddress& sourceAddr_, unsigned int sourceItf_, const std::string& sourceItfName_, size_t numberOfSockets, bool connect=true): sourceItfName(sourceItfName_), remote(remote_), sourceAddr(sourceAddr_), sourceItf(sourceItf_), name(remote_.toStringWithPort()), nameWithAddr(remote_.toStringWithPort())
+{
+  pthread_rwlock_init(&d_lock, nullptr);
+  id = getUniqueID();
+  threadStarted.clear();
+
+  mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent());
+
+  sockets.resize(numberOfSockets);
+  for (auto& fd : sockets) {
+    fd = -1;
+  }
+
+  if (connect && !IsAnyAddress(remote)) {
+    reconnect();
+    idStates.resize(g_maxOutstanding);
+    sw.start();
+  }
+
+}
diff --git a/pdns/dnsdistdist/dnsdist-lbpolicies.cc b/pdns/dnsdistdist/dnsdist-lbpolicies.cc
new file mode 100644 (file)
index 0000000..aa08a32
--- /dev/null
@@ -0,0 +1,289 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "dnsdist.hh"
+#include "dnsdist-lbpolicies.hh"
+#include "dnsdist-lua-ffi.hh"
+#include "dolog.hh"
+
+GlobalStateHolder<ServerPolicy> g_policy;
+bool g_roundrobinFailOnNoServer{false};
+
+// get server with least outstanding queries, and within those, with the lowest order, and within those: the fastest
+shared_ptr<DownstreamState> leastOutstanding(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+{
+  if (servers.size() == 1 && servers[0].second->isUp()) {
+    return servers[0].second;
+  }
+
+  vector<pair<tuple<int,int,double>, size_t>> poss;
+  /* so you might wonder, why do we go through this trouble? The data on which we sort could change during the sort,
+     which would suck royally and could even lead to crashes. So first we snapshot on what we sort, and then we sort */
+  poss.reserve(servers.size());
+  size_t position = 0;
+  for(const auto& d : servers) {
+    if(d.second->isUp()) {
+      poss.emplace_back(make_tuple(d.second->outstanding.load(), d.second->order, d.second->latencyUsec), position);
+    }
+    ++position;
+  }
+
+  if (poss.empty()) {
+    return shared_ptr<DownstreamState>();
+  }
+
+  nth_element(poss.begin(), poss.begin(), poss.end(), [](const decltype(poss)::value_type& a, const decltype(poss)::value_type& b) { return a.first < b.first; });
+  return servers.at(poss.begin()->second).second;
+}
+
+shared_ptr<DownstreamState> firstAvailable(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+{
+  for(auto& d : servers) {
+    if(d.second->isUp() && d.second->qps.check())
+      return d.second;
+  }
+  return leastOutstanding(servers, dq);
+}
+
+static shared_ptr<DownstreamState> valrandom(unsigned int val, const ServerPolicy::NumberedServerVector& servers)
+{
+  vector<pair<int, size_t>> poss;
+  poss.reserve(servers.size());
+  int sum = 0;
+  int max = std::numeric_limits<int>::max();
+
+  for(const auto& d : servers) {      // w=1, w=10 -> 1, 11
+    if(d.second->isUp()) {
+      // Don't overflow sum when adding high weights
+      if(d.second->weight > max - sum) {
+        sum = max;
+      } else {
+        sum += d.second->weight;
+      }
+
+      poss.emplace_back(sum, d.first);
+    }
+  }
+
+  // Catch poss & sum are empty to avoid SIGFPE
+  if (poss.empty()) {
+    return shared_ptr<DownstreamState>();
+  }
+
+  int r = val % sum;
+  auto p = upper_bound(poss.begin(), poss.end(),r, [](int r_, const decltype(poss)::value_type& a) { return  r_ < a.first;});
+  if (p == poss.end()) {
+    return shared_ptr<DownstreamState>();
+  }
+
+  return servers.at(p->second - 1).second;
+}
+
+shared_ptr<DownstreamState> wrandom(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+{
+  return valrandom(random(), servers);
+}
+
+uint32_t g_hashperturb;
+double g_consistentHashBalancingFactor = 0;
+
+shared_ptr<DownstreamState> whashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash)
+{
+  return valrandom(hash, servers);
+}
+
+shared_ptr<DownstreamState> whashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+{
+  return whashedFromHash(servers, dq->qname->hash(g_hashperturb));
+}
+
+shared_ptr<DownstreamState> chashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t qhash)
+{
+  unsigned int sel = std::numeric_limits<unsigned int>::max();
+  unsigned int min = std::numeric_limits<unsigned int>::max();
+  shared_ptr<DownstreamState> ret = nullptr, first = nullptr;
+
+  double targetLoad = std::numeric_limits<double>::max();
+  if (g_consistentHashBalancingFactor > 0) {
+    /* we start with one, representing the query we are currently handling */
+    double currentLoad = 1;
+    for (const auto& pair : servers) {
+      currentLoad += pair.second->outstanding;
+    }
+    targetLoad = (currentLoad / servers.size()) * g_consistentHashBalancingFactor;
+  }
+
+  for (const auto& d: servers) {
+    if (d.second->isUp() && d.second->outstanding <= targetLoad) {
+      // make sure hashes have been computed
+      if (d.second->hashes.empty()) {
+        d.second->hash();
+      }
+      {
+        ReadLock rl(&(d.second->d_lock));
+        const auto& server = d.second;
+        // we want to keep track of the last hash
+        if (min > *(server->hashes.begin())) {
+          min = *(server->hashes.begin());
+          first = server;
+        }
+
+        auto hash_it = std::lower_bound(server->hashes.begin(), server->hashes.end(), qhash);
+        if (hash_it != server->hashes.end()) {
+          if (*hash_it < sel) {
+            sel = *hash_it;
+            ret = server;
+          }
+        }
+      }
+    }
+  }
+  if (ret != nullptr) {
+    return ret;
+  }
+  if (first != nullptr) {
+    return first;
+  }
+  return shared_ptr<DownstreamState>();
+}
+
+shared_ptr<DownstreamState> chashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+{
+  return chashedFromHash(servers, dq->qname->hash(g_hashperturb));
+}
+
+shared_ptr<DownstreamState> roundrobin(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+{
+  ServerPolicy::NumberedServerVector poss;
+
+  for(auto& d : servers) {
+    if(d.second->isUp()) {
+      poss.push_back(d);
+    }
+  }
+
+  const auto *res=&poss;
+  if(poss.empty() && !g_roundrobinFailOnNoServer)
+    res = &servers;
+
+  if(res->empty())
+    return shared_ptr<DownstreamState>();
+
+  static unsigned int counter;
+  return (*res)[(counter++) % res->size()].second;
+}
+
+ServerPolicy::NumberedServerVector getDownstreamCandidates(const pools_t& pools, const std::string& poolName)
+{
+  std::shared_ptr<ServerPool> pool = getPool(pools, poolName);
+  return pool->getServers();
+}
+
+std::shared_ptr<ServerPool> createPoolIfNotExists(pools_t& pools, const string& poolName)
+{
+  std::shared_ptr<ServerPool> pool;
+  pools_t::iterator it = pools.find(poolName);
+  if (it != pools.end()) {
+    pool = it->second;
+  }
+  else {
+    if (!poolName.empty())
+      vinfolog("Creating pool %s", poolName);
+    pool = std::make_shared<ServerPool>();
+    pools.insert(std::pair<std::string,std::shared_ptr<ServerPool> >(poolName, pool));
+  }
+  return pool;
+}
+
+void setPoolPolicy(pools_t& pools, const string& poolName, std::shared_ptr<ServerPolicy> policy)
+{
+  std::shared_ptr<ServerPool> pool = createPoolIfNotExists(pools, poolName);
+  if (!poolName.empty()) {
+    vinfolog("Setting pool %s server selection policy to %s", poolName, policy->name);
+  } else {
+    vinfolog("Setting default pool server selection policy to %s", policy->name);
+  }
+  pool->policy = policy;
+}
+
+void addServerToPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server)
+{
+  std::shared_ptr<ServerPool> pool = createPoolIfNotExists(pools, poolName);
+  if (!poolName.empty()) {
+    vinfolog("Adding server to pool %s", poolName);
+  } else {
+    vinfolog("Adding server to default pool");
+  }
+  pool->addServer(server);
+}
+
+void removeServerFromPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server)
+{
+  std::shared_ptr<ServerPool> pool = getPool(pools, poolName);
+
+  if (!poolName.empty()) {
+    vinfolog("Removing server from pool %s", poolName);
+  }
+  else {
+    vinfolog("Removing server from default pool");
+  }
+
+  pool->removeServer(server);
+}
+
+std::shared_ptr<ServerPool> getPool(const pools_t& pools, const std::string& poolName)
+{
+  pools_t::const_iterator it = pools.find(poolName);
+
+  if (it == pools.end()) {
+    throw std::out_of_range("No pool named " + poolName);
+  }
+
+  return it->second;
+}
+
+std::shared_ptr<DownstreamState> getSelectedBackendFromPolicy(const ServerPolicy& policy, const ServerPolicy::NumberedServerVector& servers, DNSQuestion& dq)
+{
+  std::shared_ptr<DownstreamState> selectedBackend{nullptr};
+
+  if (policy.isLua) {
+    if (!policy.isFFI) {
+      std::lock_guard<std::mutex> lock(g_luamutex);
+      selectedBackend = policy.policy(servers, &dq);
+    }
+    else {
+      dnsdist_ffi_dnsquestion_t dnsq(&dq);
+      dnsdist_ffi_servers_list_t serversList(servers);
+      unsigned int selected = 0;
+      {
+        std::lock_guard<std::mutex> lock(g_luamutex);
+        selected = policy.ffipolicy(&serversList, &dnsq);
+      }
+      selectedBackend = servers.at(selected).second;
+    }
+  }
+  else {
+    selectedBackend = policy.policy(servers, &dq);
+  }
+
+  return selectedBackend;
+}
diff --git a/pdns/dnsdistdist/dnsdist-lbpolicies.hh b/pdns/dnsdistdist/dnsdist-lbpolicies.hh
new file mode 120000 (symlink)
index 0000000..020353f
--- /dev/null
@@ -0,0 +1 @@
+../dnsdist-lbpolicies.hh
\ No newline at end of file
index 5050879af2624ece90584bd136f704893892fcb6..a026cfb11b4d30a318c9f880423e6a1342ee1c1b 100644 (file)
@@ -64,26 +64,26 @@ void setupLuaBindingsKVS(bool client)
     }
 
     if (keyVar.type() == typeid(ComboAddress)) {
-      const auto ca = *boost::get<ComboAddress>(&keyVar);
+      const auto ca = boost::get<ComboAddress>(&keyVar);
       KeyValueLookupKeySourceIP lookup;
-      for (const auto& key : lookup.getKeys(ca)) {
+      for (const auto& key : lookup.getKeys(*ca)) {
         if (kvs->getValue(key, result)) {
           return result;
         }
       }
     }
     else if (keyVar.type() == typeid(DNSName)) {
-      DNSName dn = *boost::get<DNSName>(&keyVar);
+      const DNSName* dn = boost::get<DNSName>(&keyVar);
       KeyValueLookupKeyQName lookup(wireFormat ? *wireFormat : true);
-      for (const auto& key : lookup.getKeys(dn)) {
+      for (const auto& key : lookup.getKeys(*dn)) {
         if (kvs->getValue(key, result)) {
           return result;
         }
       }
     }
     else if (keyVar.type() == typeid(std::string)) {
-      std::string keyStr = *boost::get<std::string>(&keyVar);
-      kvs->getValue(keyStr, result);
+      const std::string* keyStr = boost::get<std::string>(&keyVar);
+      kvs->getValue(*keyStr, result);
     }
 
     return result;
diff --git a/pdns/dnsdistdist/dnsdist-lua-ffi-interface.h b/pdns/dnsdistdist/dnsdist-lua-ffi-interface.h
new file mode 100644 (file)
index 0000000..d56b735
--- /dev/null
@@ -0,0 +1,114 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+/* we don't use a guard (C++ pragma once or even #ifndef because this file (the .inc version)
+   is passed to the Lua FFI wrapper which doesn't support it */
+
+typedef struct dnsdist_ffi_dnsquestion_t dnsdist_ffi_dnsquestion_t;
+typedef struct dnsdist_ffi_servers_list_t dnsdist_ffi_servers_list_t;
+typedef struct dnsdist_ffi_server_t dnsdist_ffi_server_t;
+
+typedef struct dnsdist_ffi_ednsoption {
+  uint16_t    optionCode;
+  uint16_t    len;
+  const void* data;
+} dnsdist_ffi_ednsoption_t;
+
+typedef struct dnsdist_ffi_http_header {
+  const char* name;
+  const char* value;
+} dnsdist_ffi_http_header_t;
+
+typedef struct dnsdist_ffi_tag {
+  const char* name;
+  const char* value;
+} dnsdist_ffi_tag_t;
+
+
+void dnsdist_ffi_dnsquestion_get_localaddr(const dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize) __attribute__ ((visibility ("default")));
+uint16_t dnsdist_ffi_dnsquestion_get_local_port(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_get_remoteaddr(const dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_get_masked_remoteaddr(dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize, uint8_t bits) __attribute__ ((visibility ("default")));
+uint16_t dnsdist_ffi_dnsquestion_get_remote_port(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_get_qname_raw(const dnsdist_ffi_dnsquestion_t* dq, const char** qname, size_t* qnameSize) __attribute__ ((visibility ("default")));
+size_t dnsdist_ffi_dnsquestion_get_qname_hash(const dnsdist_ffi_dnsquestion_t* dq, size_t init);
+uint16_t dnsdist_ffi_dnsquestion_get_qtype(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+uint16_t dnsdist_ffi_dnsquestion_get_qclass(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+int dnsdist_ffi_dnsquestion_get_rcode(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+void* dnsdist_ffi_dnsquestion_get_header(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+uint16_t dnsdist_ffi_dnsquestion_get_len(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+size_t dnsdist_ffi_dnsquestion_get_size(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+uint8_t dnsdist_ffi_dnsquestion_get_opcode(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_dnsquestion_get_tcp(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_dnsquestion_get_skip_cache(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_dnsquestion_get_use_ecs(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_dnsquestion_get_add_xpf(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_dnsquestion_get_ecs_override(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+uint16_t dnsdist_ffi_dnsquestion_get_ecs_prefix_length(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_dnsquestion_is_temp_failure_ttl_set(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+uint32_t dnsdist_ffi_dnsquestion_get_temp_failure_ttl(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_dnsquestion_get_do(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_get_sni(const dnsdist_ffi_dnsquestion_t* dq, const char** sni, size_t* sniSize) __attribute__ ((visibility ("default")));
+const char* dnsdist_ffi_dnsquestion_get_tag(const dnsdist_ffi_dnsquestion_t* dq, const char* label) __attribute__ ((visibility ("default")));
+const char* dnsdist_ffi_dnsquestion_get_http_path(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+const char* dnsdist_ffi_dnsquestion_get_http_query_string(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+const char* dnsdist_ffi_dnsquestion_get_http_host(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+const char* dnsdist_ffi_dnsquestion_get_http_scheme(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+
+// returns the length of the resulting 'out' array. 'out' is not set if the length is 0
+size_t dnsdist_ffi_dnsquestion_get_edns_options(dnsdist_ffi_dnsquestion_t* ref, const dnsdist_ffi_ednsoption_t** out) __attribute__ ((visibility ("default")));
+size_t dnsdist_ffi_dnsquestion_get_http_headers(dnsdist_ffi_dnsquestion_t* ref, const dnsdist_ffi_http_header_t** out) __attribute__ ((visibility ("default")));
+size_t dnsdist_ffi_dnsquestion_get_tag_array(dnsdist_ffi_dnsquestion_t* ref, const dnsdist_ffi_tag_t** out) __attribute__ ((visibility ("default")));
+
+void dnsdist_ffi_dnsquestion_set_result(dnsdist_ffi_dnsquestion_t* dq, const char* str, size_t strSize) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_set_rcode(dnsdist_ffi_dnsquestion_t* dq, int rcode) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_set_len(dnsdist_ffi_dnsquestion_t* dq, uint16_t len) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_set_skip_cache(dnsdist_ffi_dnsquestion_t* dq, bool skipCache) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_set_use_ecs(dnsdist_ffi_dnsquestion_t* dq, bool useECS) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_set_ecs_override(dnsdist_ffi_dnsquestion_t* dq, bool ecsOverride) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_set_ecs_prefix_length(dnsdist_ffi_dnsquestion_t* dq, uint16_t ecsPrefixLength) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_set_temp_failure_ttl(dnsdist_ffi_dnsquestion_t* dq, uint32_t tempFailureTTL) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_unset_temp_failure_ttl(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_set_tag(dnsdist_ffi_dnsquestion_t* dq, const char* label, const char* value) __attribute__ ((visibility ("default")));
+
+void dnsdist_ffi_dnsquestion_set_http_response(dnsdist_ffi_dnsquestion_t* dq, uint16_t statusCode, const char* body, const char* contentType) __attribute__ ((visibility ("default")));
+
+size_t dnsdist_ffi_dnsquestion_get_trailing_data(dnsdist_ffi_dnsquestion_t* dq, const char** out) __attribute__ ((visibility ("default")));
+
+bool dnsdist_ffi_dnsquestion_set_trailing_data(dnsdist_ffi_dnsquestion_t* dq, const char* data, size_t dataLen) __attribute__ ((visibility ("default")));
+
+void dnsdist_ffi_dnsquestion_send_trap(dnsdist_ffi_dnsquestion_t* dq, const char* reason, size_t reasonLen) __attribute__ ((visibility ("default")));
+
+typedef struct dnsdist_ffi_servers_list_t dnsdist_ffi_servers_list_t;
+typedef struct dnsdist_ffi_server_t dnsdist_ffi_server_t;
+
+size_t dnsdist_ffi_servers_list_get_count(const dnsdist_ffi_servers_list_t* list) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_servers_list_get_server(const dnsdist_ffi_servers_list_t* list, size_t idx, const dnsdist_ffi_server_t** out) __attribute__ ((visibility ("default")));
+size_t dnsdist_ffi_servers_list_chashed(const dnsdist_ffi_servers_list_t* list, const dnsdist_ffi_dnsquestion_t* dq, size_t hash) __attribute__ ((visibility ("default")));
+size_t dnsdist_ffi_servers_list_whashed(const dnsdist_ffi_servers_list_t* list, const dnsdist_ffi_dnsquestion_t* dq, size_t hash) __attribute__ ((visibility ("default")));
+
+uint64_t dnsdist_ffi_server_get_outstanding(const dnsdist_ffi_server_t* server) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_server_is_up(const dnsdist_ffi_server_t* server) __attribute__ ((visibility ("default")));
+const char* dnsdist_ffi_server_get_name(const dnsdist_ffi_server_t* server) __attribute__ ((visibility ("default")));
+const char* dnsdist_ffi_server_get_name_with_addr(const dnsdist_ffi_server_t* server) __attribute__ ((visibility ("default")));
+int dnsdist_ffi_server_get_weight(const dnsdist_ffi_server_t* server) __attribute__ ((visibility ("default")));
+int dnsdist_ffi_server_get_order(const dnsdist_ffi_server_t* server) __attribute__ ((visibility ("default")));
diff --git a/pdns/dnsdistdist/dnsdist-lua-ffi.cc b/pdns/dnsdistdist/dnsdist-lua-ffi.cc
new file mode 100644 (file)
index 0000000..f709fea
--- /dev/null
@@ -0,0 +1,501 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "dnsdist-lua-ffi.hh"
+#include "dnsdist-ecs.hh"
+
+uint16_t dnsdist_ffi_dnsquestion_get_qtype(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  return dq->dq->qtype;
+}
+
+uint16_t dnsdist_ffi_dnsquestion_get_qclass(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  return dq->dq->qclass;
+}
+
+static void dnsdist_ffi_comboaddress_to_raw(const ComboAddress& ca, const void** addr, size_t* addrSize)
+{
+  if (ca.isIPv4()) {
+    *addr = &ca.sin4.sin_addr.s_addr;
+    *addrSize = sizeof(ca.sin4.sin_addr.s_addr);
+  }
+  else {
+    *addr = &ca.sin6.sin6_addr.s6_addr;
+    *addrSize = sizeof(ca.sin6.sin6_addr.s6_addr);
+  }
+}
+
+void dnsdist_ffi_dnsquestion_get_localaddr(const dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize)
+{
+  dnsdist_ffi_comboaddress_to_raw(*dq->dq->local, addr, addrSize);
+}
+
+void dnsdist_ffi_dnsquestion_get_remoteaddr(const dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize)
+{
+  dnsdist_ffi_comboaddress_to_raw(*dq->dq->remote, addr, addrSize);
+}
+
+void dnsdist_ffi_dnsquestion_get_masked_remoteaddr(dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize, uint8_t bits)
+{
+  dq->maskedRemote = Netmask(*dq->dq->remote, bits).getMaskedNetwork();
+  dnsdist_ffi_comboaddress_to_raw(dq->maskedRemote, addr, addrSize);
+}
+
+uint16_t dnsdist_ffi_dnsquestion_get_local_port(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  return dq->dq->local->getPort();
+}
+
+uint16_t dnsdist_ffi_dnsquestion_get_remote_port(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  return dq->dq->remote->getPort();
+}
+
+void dnsdist_ffi_dnsquestion_get_qname_raw(const dnsdist_ffi_dnsquestion_t* dq, const char** qname, size_t* qnameSize)
+{
+  const auto& storage = dq->dq->qname->getStorage();
+  *qname = storage.data();
+  *qnameSize = storage.size();
+}
+
+size_t dnsdist_ffi_dnsquestion_get_qname_hash(const dnsdist_ffi_dnsquestion_t* dq, size_t init)
+{
+  return dq->dq->qname->hash(init);
+}
+
+int dnsdist_ffi_dnsquestion_get_rcode(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  return dq->dq->dh->rcode;
+}
+
+void* dnsdist_ffi_dnsquestion_get_header(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  return dq->dq->dh;
+}
+
+uint16_t dnsdist_ffi_dnsquestion_get_len(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  return dq->dq->len;
+}
+
+size_t dnsdist_ffi_dnsquestion_get_size(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  return dq->dq->size;
+}
+
+uint8_t dnsdist_ffi_dnsquestion_get_opcode(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  return dq->dq->dh->opcode;
+}
+
+bool dnsdist_ffi_dnsquestion_get_tcp(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  return dq->dq->tcp;
+}
+
+bool dnsdist_ffi_dnsquestion_get_skip_cache(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  return dq->dq->skipCache;
+}
+
+bool dnsdist_ffi_dnsquestion_get_use_ecs(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  return dq->dq->useECS;
+}
+
+bool dnsdist_ffi_dnsquestion_get_add_xpf(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  return dq->dq->addXPF;
+}
+
+bool dnsdist_ffi_dnsquestion_get_ecs_override(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  return dq->dq->ecsOverride;
+}
+
+uint16_t dnsdist_ffi_dnsquestion_get_ecs_prefix_length(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  return dq->dq->ecsPrefixLength;
+}
+
+bool dnsdist_ffi_dnsquestion_is_temp_failure_ttl_set(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  return dq->dq->tempFailureTTL != boost::none;
+}
+
+uint32_t dnsdist_ffi_dnsquestion_get_temp_failure_ttl(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  if (dq->dq->tempFailureTTL) {
+    return *dq->dq->tempFailureTTL;
+  }
+  return 0;
+}
+
+bool dnsdist_ffi_dnsquestion_get_do(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  return getEDNSZ(*dq->dq) & EDNS_HEADER_FLAG_DO;
+}
+
+void dnsdist_ffi_dnsquestion_get_sni(const dnsdist_ffi_dnsquestion_t* dq, const char** sni, size_t* sniSize)
+{
+  *sniSize = dq->dq->sni.size();
+  *sni = dq->dq->sni.c_str();
+}
+
+const char* dnsdist_ffi_dnsquestion_get_tag(const dnsdist_ffi_dnsquestion_t* dq, const char* label)
+{
+  const char * result = nullptr;
+
+  if (dq->dq->qTag != nullptr) {
+    const auto it = dq->dq->qTag->find(label);
+    if (it != dq->dq->qTag->cend()) {
+      result = it->second.c_str();
+    }
+  }
+
+  return result;
+}
+
+const char* dnsdist_ffi_dnsquestion_get_http_path(dnsdist_ffi_dnsquestion_t* dq)
+{
+  if (!dq->httpPath) {
+    if (dq->dq->du == nullptr) {
+      return nullptr;
+    }
+#ifdef HAVE_DNS_OVER_HTTPS
+    dq->httpPath = dq->dq->du->getHTTPPath();
+#endif /* HAVE_DNS_OVER_HTTPS */
+  }
+  if (dq->httpPath) {
+    return dq->httpPath->c_str();
+  }
+  return nullptr;
+}
+
+const char* dnsdist_ffi_dnsquestion_get_http_query_string(dnsdist_ffi_dnsquestion_t* dq)
+{
+  if (!dq->httpQueryString) {
+    if (dq->dq->du == nullptr) {
+      return nullptr;
+    }
+#ifdef HAVE_DNS_OVER_HTTPS
+    dq->httpQueryString = dq->dq->du->getHTTPQueryString();
+#endif /* HAVE_DNS_OVER_HTTPS */
+  }
+  if (dq->httpQueryString) {
+    return dq->httpQueryString->c_str();
+  }
+  return nullptr;
+}
+
+const char* dnsdist_ffi_dnsquestion_get_http_host(dnsdist_ffi_dnsquestion_t* dq)
+{
+  if (!dq->httpHost) {
+    if (dq->dq->du == nullptr) {
+      return nullptr;
+    }
+#ifdef HAVE_DNS_OVER_HTTPS
+    dq->httpHost = dq->dq->du->getHTTPHost();
+#endif /* HAVE_DNS_OVER_HTTPS */
+  }
+  if (dq->httpHost) {
+    return dq->httpHost->c_str();
+  }
+  return nullptr;
+}
+
+const char* dnsdist_ffi_dnsquestion_get_http_scheme(dnsdist_ffi_dnsquestion_t* dq)
+{
+  if (!dq->httpScheme) {
+    if (dq->dq->du == nullptr) {
+      return nullptr;
+    }
+#ifdef HAVE_DNS_OVER_HTTPS
+    dq->httpScheme = dq->dq->du->getHTTPScheme();
+#endif /* HAVE_DNS_OVER_HTTPS */
+  }
+  if (dq->httpScheme) {
+    return dq->httpScheme->c_str();
+  }
+  return nullptr;
+}
+
+static void fill_edns_option(const EDNSOptionViewValue& value, dnsdist_ffi_ednsoption_t& option)
+{
+  option.len = value.size;
+  option.data = nullptr;
+
+  if (value.size > 0) {
+    option.data = value.content;
+  }
+}
+
+// returns the length of the resulting 'out' array. 'out' is not set if the length is 0
+size_t dnsdist_ffi_dnsquestion_get_edns_options(dnsdist_ffi_dnsquestion_t* dq, const dnsdist_ffi_ednsoption_t** out)
+{
+  if (dq->dq->ednsOptions == nullptr) {
+    parseEDNSOptions(*(dq->dq));
+  }
+
+  size_t totalCount = 0;
+  for (const auto& option : *dq->dq->ednsOptions) {
+    totalCount += option.second.values.size();
+  }
+
+  dq->ednsOptionsVect.clear();
+  dq->ednsOptionsVect.resize(totalCount);
+  size_t pos = 0;
+  for (const auto& option : *dq->dq->ednsOptions) {
+    for (const auto& entry : option.second.values) {
+      fill_edns_option(entry, dq->ednsOptionsVect.at(pos));
+      dq->ednsOptionsVect.at(pos).optionCode = option.first;
+      pos++;
+    }
+  }
+
+  if (totalCount > 0) {
+    *out = dq->ednsOptionsVect.data();
+  }
+
+  return totalCount;
+}
+
+size_t dnsdist_ffi_dnsquestion_get_http_headers(dnsdist_ffi_dnsquestion_t* dq, const dnsdist_ffi_http_header_t** out)
+{
+  if (dq->dq->du == nullptr) {
+    return 0;
+  }
+
+#ifdef HAVE_DNS_OVER_HTTPS
+  dq->httpHeaders = dq->dq->du->getHTTPHeaders();
+  dq->httpHeadersVect.clear();
+  dq->httpHeadersVect.resize(dq->httpHeaders.size());
+  size_t pos = 0;
+  for (const auto& header : dq->httpHeaders) {
+    dq->httpHeadersVect.at(pos).name = header.first.c_str();
+    dq->httpHeadersVect.at(pos).value = header.second.c_str();
+    ++pos;
+  }
+
+  if (!dq->httpHeadersVect.empty()) {
+    *out = dq->httpHeadersVect.data();
+  }
+
+  return dq->httpHeadersVect.size();
+#else
+  return 0;
+#endif
+}
+
+size_t dnsdist_ffi_dnsquestion_get_tag_array(dnsdist_ffi_dnsquestion_t* dq, const dnsdist_ffi_tag_t** out)
+{
+  if (dq->dq->qTag == nullptr || dq->dq->qTag->size() == 0) {
+    return 0;
+  }
+
+  dq->tagsVect.clear();
+  dq->tagsVect.resize(dq->dq->qTag->size());
+  size_t pos = 0;
+
+  for (const auto& tag : *dq->dq->qTag) {
+    auto& entry = dq->tagsVect.at(pos);
+    entry.name = tag.first.c_str();
+    entry.value = tag.second.c_str();
+    ++pos;
+  }
+
+
+  if (!dq->tagsVect.empty()) {
+    *out = dq->tagsVect.data();
+  }
+
+  return dq->tagsVect.size();
+}
+
+void dnsdist_ffi_dnsquestion_set_result(dnsdist_ffi_dnsquestion_t* dq, const char* str, size_t strSize)
+{
+  dq->result = std::string(str, strSize);
+}
+
+void dnsdist_ffi_dnsquestion_set_http_response(dnsdist_ffi_dnsquestion_t* dq, uint16_t statusCode, const char* body, const char* contentType)
+{
+  if (dq->dq->du == nullptr) {
+    return;
+  }
+
+#ifdef HAVE_DNS_OVER_HTTPS
+  dq->dq->du->setHTTPResponse(statusCode, body, contentType);
+  dq->dq->dh->qr = true;
+#endif
+}
+
+void dnsdist_ffi_dnsquestion_set_rcode(dnsdist_ffi_dnsquestion_t* dq, int rcode)
+{
+  dq->dq->dh->rcode = rcode;
+  dq->dq->dh->qr = true;
+}
+
+void dnsdist_ffi_dnsquestion_set_len(dnsdist_ffi_dnsquestion_t* dq, uint16_t len)
+{
+  dq->dq->len = len;
+}
+
+void dnsdist_ffi_dnsquestion_set_skip_cache(dnsdist_ffi_dnsquestion_t* dq, bool skipCache)
+{
+  dq->dq->skipCache = skipCache;
+}
+
+void dnsdist_ffi_dnsquestion_set_use_ecs(dnsdist_ffi_dnsquestion_t* dq, bool useECS)
+{
+  dq->dq->useECS = useECS;
+}
+
+void dnsdist_ffi_dnsquestion_set_ecs_override(dnsdist_ffi_dnsquestion_t* dq, bool ecsOverride)
+{
+  dq->dq->ecsOverride = ecsOverride;
+}
+
+void dnsdist_ffi_dnsquestion_set_ecs_prefix_length(dnsdist_ffi_dnsquestion_t* dq, uint16_t ecsPrefixLength)
+{
+  dq->dq->ecsPrefixLength = ecsPrefixLength;
+}
+
+void dnsdist_ffi_dnsquestion_set_temp_failure_ttl(dnsdist_ffi_dnsquestion_t* dq, uint32_t tempFailureTTL)
+{
+  dq->dq->tempFailureTTL = tempFailureTTL;
+}
+
+void dnsdist_ffi_dnsquestion_unset_temp_failure_ttl(dnsdist_ffi_dnsquestion_t* dq)
+{
+  dq->dq->tempFailureTTL = boost::none;
+}
+
+void dnsdist_ffi_dnsquestion_set_tag(dnsdist_ffi_dnsquestion_t* dq, const char* label, const char* value)
+{
+  if (!dq->dq->qTag) {
+    dq->dq->qTag = std::make_shared<QTag>();
+  }
+
+  dq->dq->qTag->insert({label, value});
+}
+
+size_t dnsdist_ffi_dnsquestion_get_trailing_data(dnsdist_ffi_dnsquestion_t* dq, const char** out)
+{
+  dq->trailingData = dq->dq->getTrailingData();
+  if (!dq->trailingData.empty()) {
+    *out = dq->trailingData.data();
+  }
+
+  return dq->trailingData.size();
+}
+
+bool dnsdist_ffi_dnsquestion_set_trailing_data(dnsdist_ffi_dnsquestion_t* dq, const char* data, size_t dataLen)
+{
+  return dq->dq->setTrailingData(std::string(data, dataLen));
+}
+
+void dnsdist_ffi_dnsquestion_send_trap(dnsdist_ffi_dnsquestion_t* dq, const char* reason, size_t reasonLen)
+{
+  if (g_snmpAgent && g_snmpTrapsEnabled) {
+    g_snmpAgent->sendDNSTrap(*dq->dq, std::string(reason, reasonLen));
+  }
+}
+
+size_t dnsdist_ffi_servers_list_get_count(const dnsdist_ffi_servers_list_t* list)
+{
+  return list->ffiServers.size();
+}
+
+void dnsdist_ffi_servers_list_get_server(const dnsdist_ffi_servers_list_t* list, size_t idx, const dnsdist_ffi_server_t** out)
+{
+  *out = &list->ffiServers.at(idx);
+}
+
+static size_t dnsdist_ffi_servers_get_index_from_server(const ServerPolicy::NumberedServerVector& servers, const std::shared_ptr<DownstreamState>& server)
+{
+  for (const auto& pair : servers) {
+    if (pair.second == server) {
+      return pair.first - 1;
+    }
+  }
+  throw std::runtime_error("Unable to find servers in server list");
+}
+
+size_t dnsdist_ffi_servers_list_chashed(const dnsdist_ffi_servers_list_t* list, const dnsdist_ffi_dnsquestion_t* dq, size_t hash)
+{
+  auto server = chashedFromHash(list->servers, hash);
+  return dnsdist_ffi_servers_get_index_from_server(list->servers, server);
+}
+
+size_t dnsdist_ffi_servers_list_whashed(const dnsdist_ffi_servers_list_t* list, const dnsdist_ffi_dnsquestion_t* dq, size_t hash)
+{
+  auto server = whashedFromHash(list->servers, hash);
+  return dnsdist_ffi_servers_get_index_from_server(list->servers, server);
+}
+
+uint64_t dnsdist_ffi_server_get_outstanding(const dnsdist_ffi_server_t* server)
+{
+  return server->server->outstanding;
+}
+
+int dnsdist_ffi_server_get_weight(const dnsdist_ffi_server_t* server)
+{
+  return server->server->weight;
+}
+
+int dnsdist_ffi_server_get_order(const dnsdist_ffi_server_t* server)
+{
+  return server->server->order;
+}
+
+bool dnsdist_ffi_server_is_up(const dnsdist_ffi_server_t* server)
+{
+  return server->server->isUp();
+}
+
+const char* dnsdist_ffi_server_get_name(const dnsdist_ffi_server_t* server)
+{
+  return server->server->getName().c_str();
+}
+
+const char* dnsdist_ffi_server_get_name_with_addr(const dnsdist_ffi_server_t* server)
+{
+  return server->server->getNameWithAddr().c_str();
+}
+
+const std::string& getLuaFFIWrappers()
+{
+  static const std::string interface =
+#include "dnsdist-lua-ffi-interface.inc"
+    ;
+  static const std::string code = R"FFICodeContent(
+  local ffi = require("ffi")
+  local C = ffi.C
+
+  ffi.cdef[[
+)FFICodeContent" + interface + R"FFICodeContent(
+  ]]
+
+)FFICodeContent";
+  return code;
+}
diff --git a/pdns/dnsdistdist/dnsdist-lua-ffi.hh b/pdns/dnsdistdist/dnsdist-lua-ffi.hh
new file mode 100644 (file)
index 0000000..63156db
--- /dev/null
@@ -0,0 +1,109 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+#include "dnsdist.hh"
+
+extern "C" {
+#include "dnsdist-lua-ffi-interface.h"
+}
+
+// dnsdist_ffi_dnsquestion_t is a lightuserdata
+template<>
+struct LuaContext::Pusher<dnsdist_ffi_dnsquestion_t*> {
+    static const int minSize = 1;
+    static const int maxSize = 1;
+
+    static PushedObject push(lua_State* state, dnsdist_ffi_dnsquestion_t* ptr) noexcept {
+        lua_pushlightuserdata(state, ptr);
+        return PushedObject{state, 1};
+    }
+};
+
+struct dnsdist_ffi_dnsquestion_t
+{
+  dnsdist_ffi_dnsquestion_t(DNSQuestion* dq_): dq(dq_)
+  {
+  }
+
+  DNSQuestion* dq{nullptr};
+  std::vector<dnsdist_ffi_ednsoption_t> ednsOptionsVect;
+  std::vector<dnsdist_ffi_http_header_t> httpHeadersVect;
+  std::vector<dnsdist_ffi_tag_t> tagsVect;
+  std::unordered_map<std::string, std::string> httpHeaders;
+  std::string trailingData;
+  ComboAddress maskedRemote;
+  boost::optional<std::string> result{boost::none};
+  boost::optional<std::string> httpPath{boost::none};
+  boost::optional<std::string> httpQueryString{boost::none};
+  boost::optional<std::string> httpHost{boost::none};
+  boost::optional<std::string> httpScheme{boost::none};
+};
+
+// dnsdist_ffi_server_t is a lightuserdata
+template<>
+struct LuaContext::Pusher<dnsdist_ffi_server_t*> {
+    static const int minSize = 1;
+    static const int maxSize = 1;
+
+    static PushedObject push(lua_State* state, dnsdist_ffi_server_t* ptr) noexcept {
+        lua_pushlightuserdata(state, ptr);
+        return PushedObject{state, 1};
+    }
+};
+
+struct dnsdist_ffi_server_t
+{
+  dnsdist_ffi_server_t(const std::shared_ptr<DownstreamState>& server_): server(server_)
+  {
+  }
+
+  const std::shared_ptr<DownstreamState>& server;
+};
+
+// dnsdist_ffi_servers_list_t is a lightuserdata
+template<>
+struct LuaContext::Pusher<dnsdist_ffi_servers_list_t*> {
+    static const int minSize = 1;
+    static const int maxSize = 1;
+
+    static PushedObject push(lua_State* state, dnsdist_ffi_servers_list_t* ptr) noexcept {
+        lua_pushlightuserdata(state, ptr);
+        return PushedObject{state, 1};
+    }
+};
+
+struct dnsdist_ffi_servers_list_t
+{
+  dnsdist_ffi_servers_list_t(const ServerPolicy::NumberedServerVector& servers_): servers(servers_)
+  {
+    ffiServers.reserve(servers.size());
+    for (const auto& server: servers) {
+      ffiServers.push_back(dnsdist_ffi_server_t(server.second));
+    }
+  }
+
+  std::vector<dnsdist_ffi_server_t> ffiServers;
+  const ServerPolicy::NumberedServerVector& servers;
+};
+
+const std::string& getLuaFFIWrappers();
index 912e665488ec7c9a32bf1b98a3c4eea1a838d34f..0474a1e1504b8d3c9acee069d7ee8426b40141ee 100644 (file)
@@ -25,6 +25,8 @@
 #include "dnsdist.hh"
 #include "dnsdist-ecs.hh"
 #include "dnsdist-kvs.hh"
+#include "dnsdist-lua-ffi.hh"
+#include "dolog.hh"
 #include "dnsparser.hh"
 
 class MaxQPSIPRule : public DNSRule
@@ -1124,3 +1126,60 @@ private:
   std::shared_ptr<KeyValueStore> d_kvs;
   std::shared_ptr<KeyValueLookupKey> d_key;
 };
+
+class LuaRule : public DNSRule
+{
+public:
+  typedef std::function<bool(const DNSQuestion* dq)> func_t;
+  LuaRule(const func_t& func): d_func(func)
+  {}
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    try {
+      std::lock_guard<std::mutex> lock(g_luamutex);
+      return d_func(dq);
+    } catch (const std::exception &e) {
+      warnlog("LuaRule failed inside Lua: %s", e.what());
+    } catch (...) {
+      warnlog("LuaRule failed inside Lua: [unknown exception]");
+    }
+    return false;
+  }
+
+  string toString() const override
+  {
+    return "Lua script";
+  }
+private:
+  func_t d_func;
+};
+
+class LuaFFIRule : public DNSRule
+{
+public:
+  typedef std::function<bool(dnsdist_ffi_dnsquestion_t* dq)> func_t;
+  LuaFFIRule(const func_t& func): d_func(func)
+  {}
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    dnsdist_ffi_dnsquestion_t dqffi(const_cast<DNSQuestion*>(dq));
+    try {
+      std::lock_guard<std::mutex> lock(g_luamutex);
+      return d_func(&dqffi);
+    } catch (const std::exception &e) {
+      warnlog("LuaFFIRule failed inside Lua: %s", e.what());
+    } catch (...) {
+      warnlog("LuaFFIRule failed inside Lua: [unknown exception]");
+    }
+    return false;
+  }
+
+  string toString() const override
+  {
+    return "Lua FFI script";
+  }
+private:
+  func_t d_func;
+};
index 4fbbd171b703537924583e27d89ec3ad66e9e8d6..46f846c68ca31f044b9d0df1e6470314d98a64b4 100644 (file)
@@ -31,7 +31,7 @@ Large installations are advised to increase the default value at the cost of a s
 Most of the query processing is done in C++ for maximum performance, but some operations are executed in Lua for maximum flexibility:
 
  * Rules added by :func:`addLuaAction`
- * Server selection policies defined via :func:`setServerPolicyLua` or :func:`newServerPolicy`
+ * Server selection policies defined via :func:`setServerPolicyLua`, :func:`setServerPolicyLuaFFI` or :func:`newServerPolicy`
 
 While Lua is fast, its use should be restricted to the strict necessary in order to achieve maximum performance, it might be worth considering using LuaJIT instead of Lua.
 When Lua inspection is needed, the best course of action is to restrict the queries sent to Lua inspection by using :func:`addLuaAction` with a selector.
index f86dc6d8523fb2373f5bfd34434950dd6f74acbe..47c4fa55bb6cea552bed3a18c73c2acb1b71aa83 100644 (file)
@@ -114,6 +114,35 @@ ServerPolicy Objects
   :param servers: A list of :class:`Server` objects
   :param DNSQuestion dq: The incoming query
 
+  .. attribute:: ServerPolicy.ffipolicy
+
+    .. versionadded: 1.5.0
+
+    For policies implemented using the Lua FFI interface, the policy function itself.
+
+  .. attribute:: ServerPolicy.isFFI
+
+    .. versionadded: 1.5.0
+
+    Whether a Lua-based policy is implemented using the FFI interface.
+
+  .. attribute:: ServerPolicy.isLua
+
+    Whether this policy is a native (C++) policy or a Lua-based one.
+
+  .. attribute:: ServerPolicy.name
+
+    The name of the policy.
+
+  .. attribute:: ServerPolicy.policy
+
+    The policy function itself, except for FFI policies.
+
+  .. method:: Server:toString()
+
+    Return a textual representation of the policy.
+
+
 Functions
 ---------
 
@@ -141,11 +170,20 @@ Functions
 
 .. function:: setServerPolicyLua(name, function)
 
-  Set server selection policy to one named `name`` and provided by ``function``.
+  Set server selection policy to one named ``name`` and provided by ``function``.
 
   :param string name: name for this policy
   :param string function: name of the function
 
+.. function:: setServerPolicyLuaFFI(name, function)
+
+  .. versionadded:: 1.5.0
+
+  Set server selection policy to one named ``name`` and provided by the FFI function ``function``.
+
+  :param string name: name for this policy
+  :param string function: name of the FFI function
+
 .. function:: setServFailWhenNoServer(value)
 
   If set, return a ServFail when no servers are available, instead of the default behaviour of dropping the query.
index a0cdeb9baba262c84b1c2b72c4d8df0cc76f4ec0..cbd1de0d53d82c266c09e866e2711cbb1c99900d 100644 (file)
@@ -18,6 +18,7 @@ Within dnsdist several core object types exist:
 * :class:`Server`: generated with :func:`newServer`, represents a downstream server
 * :class:`ComboAddress`: represents an IP address and port
 * :class:`DNSName`: represents a domain name
+* :class:`Netmask`: represents a netmask
 * :class:`NetmaskGroup`: represents a group of netmasks
 * :class:`QPSLimiter`: implements a QPS-based filter
 * :class:`SuffixMatchNode`: represents a group of domain suffixes for rapid testing of membership
diff --git a/pdns/dnsdistdist/docs/reference/netmask.rst b/pdns/dnsdistdist/docs/reference/netmask.rst
new file mode 100644 (file)
index 0000000..89a6630
--- /dev/null
@@ -0,0 +1,53 @@
+Netmask
+=======
+
+.. function:: newNetmask(str) -> Netmask
+              newNetmask(ca, bits) -> Netmask
+
+  .. versionadded:: 1.5.0
+
+  Returns a Netmask
+
+  :param string str: A netmask, like ``192.0.2.0/24``.
+  :param ComboAddress ca: A :class:`ComboAddress`.
+  :param int bits: The number of bits in this netmask.
+
+.. class:: Netmask
+
+  .. versionadded:: 1.5.0
+
+   Represents a netmask.
+
+  .. method:: Netmask:getBits() -> int
+
+    Return the number of bits of this netmask, for example ``24`` for ``192.0.2.0/24``.
+
+  .. method:: Netmask:getMaskedNetwork() -> ComboAddress
+
+    Return a :class:`ComboAddress` object representing the base network of this netmask object after masking any additional bits if necessary (for example ``192.0.2.0`` if the netmask was constructed with ``newNetmask('192.0.2.1/24')).
+
+  .. method:: Netmask:empty() -> bool
+
+    Return true if the netmask is empty, meaning that the netmask has not been set to a proper value.
+
+  .. method:: Netmask:isIPv4() -> bool
+
+    Return true if the netmask is an IPv4 one.
+
+  .. method:: Netmask:isIPv6() -> bool
+
+    Return true if the netmask is an IPv6 one.
+
+  .. method:: Netmask:getNetwork() -> ComboAddress
+
+    Return a :class:`ComboAddress` object representing the base network of this netmask object.
+
+  .. method:: Netmask:match(str) -> bool
+
+    Return true if the address passed in the ``str`` parameter belongs to this netmask.
+
+    :param string str: A network address, like ``192.0.2.0``.
+
+  .. method:: Netmask:toString() -> string
+
+    Return a string representation of the netmask, for example ``192.0.2.0/24``.
index 9ff60543c8993cabce57cacca0d018e660abec18..13e4ac7b8dc61b418b3ca911ec10b9310d2d1f1d 100644 (file)
@@ -173,7 +173,7 @@ Rule Generators
 
   ::
 
-    function luarule(dq)
+    function luaaction(dq)
       if(dq.qtype==DNSQType.NAPTR)
       then
         return DNSAction.Pool, "abuse" -- send to abuse pool
@@ -183,7 +183,7 @@ Rule Generators
       end
     end
 
-    addLuaAction(AllRule(), luarule)
+    addLuaAction(AllRule(), luaaction)
 
 .. function:: addLuaResponseAction(DNSrule, function [, options])
 
@@ -616,6 +616,26 @@ These ``DNSRule``\ s be one of the following items:
   :param KeyValueStore kvs: The key value store to query
   :param KeyValueLookupKey lookupKey: The key to use for the lookup
 
+.. function:: LuaFFIRule(function)
+
+  .. versionadded:: 1.5.0
+
+  Invoke a Lua FFI function that accepts a pointer to a ``dnsdist_ffi_dnsquestion_t`` object, whose bindings are defined in ``dnsdist-lua-ffi.hh``.
+
+  The ``function`` should return true if the query matches, or false otherwise. If the Lua code fails, false is returned.
+
+  :param string function: the name of a Lua function
+
+.. function:: LuaRule(function)
+
+  .. versionadded:: 1.5.0
+
+  Invoke a Lua function that accepts a :class:`DNSQuestion` object.
+
+  The ``function`` should return true if the query matches, or false otherwise. If the Lua code fails, false is returned.
+
+  :param string function: the name of a Lua function
+
 .. function:: MaxQPSIPRule(qps[, v4Mask[, v6Mask[, burst[, expiration[, cleanupDelay[, scanFraction]]]]]])
 
   .. versionchanged:: 1.3.1
@@ -1065,6 +1085,26 @@ The following actions exist.
 
   :param string function: the name of a Lua function
 
+.. function:: LuaFFIAction(function)
+
+  .. versionadded:: 1.5.0
+
+  Invoke a Lua FFI function that accepts a pointer to a ``dnsdist_ffi_dnsquestion_t`` object, whose bindings are defined in ``dnsdist-lua-ffi.hh``.
+
+  The ``function`` should return a :ref:`DNSAction`. If the Lua code fails, ServFail is returned.
+
+  :param string function: the name of a Lua function
+
+.. function:: LuaFFIResponseAction(function)
+
+  .. versionadded:: 1.5.0
+
+  Invoke a Lua FFI function that accepts a pointer to a ``dnsdist_ffi_dnsquestion_t`` object, whose bindings are defined in ``dnsdist-lua-ffi.hh``.
+
+  The ``function`` should return a :ref:`DNSResponseAction`. If the Lua code fails, ServFail is returned.
+
+  :param string function: the name of a Lua function
+
 .. function:: LuaResponseAction(function)
 
   Invoke a Lua function that accepts a :class:`DNSResponse`.
diff --git a/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc b/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc
new file mode 100644 (file)
index 0000000..4f858d7
--- /dev/null
@@ -0,0 +1,732 @@
+
+#define BOOST_TEST_DYN_LINK
+#define BOOST_TEST_NO_MAIN
+
+#include <boost/test/unit_test.hpp>
+
+#include "dnsdist.hh"
+#include "dnsdist-lua-ffi.hh"
+#include "dolog.hh"
+
+uint16_t g_maxOutstanding{std::numeric_limits<uint16_t>::max()};
+
+std::mutex g_luamutex;
+#include "ext/luawrapper/include/LuaContext.hpp"
+LuaContext g_lua;
+
+bool g_snmpEnabled{false};
+bool g_snmpTrapsEnabled{false};
+DNSDistSNMPAgent* g_snmpAgent{nullptr};
+
+#if BENCH_POLICIES
+bool g_verbose{true};
+bool g_syslog{true};
+#include "dnsdist-rings.hh"
+Rings g_rings;
+GlobalStateHolder<NetmaskTree<DynBlock>> g_dynblockNMG;
+GlobalStateHolder<SuffixMatchTree<DynBlock>> g_dynblockSMT;
+#endif /* BENCH_POLICIES */
+
+/* add stub implementations, we don't want to include the corresponding object files
+   and their dependencies */
+
+#ifdef HAVE_DNS_OVER_HTTPS
+std::unordered_map<std::string, std::string> DOHUnit::getHTTPHeaders() const
+{
+  return {};
+}
+
+std::string DOHUnit::getHTTPPath() const
+{
+  return "";
+}
+
+std::string DOHUnit::getHTTPHost() const
+{
+  return "";
+}
+
+std::string DOHUnit::getHTTPScheme() const
+{
+  return "";
+}
+
+std::string DOHUnit::getHTTPQueryString() const
+{
+  return "";
+}
+
+void DOHUnit::setHTTPResponse(uint16_t statusCode, const std::string& body_, const std::string& contentType_)
+{
+}
+#endif /* HAVE_DNS_OVER_HTTPS */
+
+std::string DNSQuestion::getTrailingData() const
+{
+  return "";
+}
+
+bool DNSQuestion::setTrailingData(const std::string& tail)
+{
+  return false;
+}
+
+bool DNSDistSNMPAgent::sendDNSTrap(const DNSQuestion& dq, const std::string& reason)
+{
+  return false;
+}
+
+static DNSQuestion getDQ(const DNSName* providedName = nullptr)
+{
+  static const DNSName qname("powerdns.com.");
+  static const ComboAddress lc("127.0.0.1:53");
+  static const ComboAddress rem("192.0.2.1:42");
+  static struct timespec queryRealTime;
+  static struct dnsheader dh;
+
+  memset(&dh, 0, sizeof(dh));
+  uint16_t qtype = QType::A;
+  uint16_t qclass = QClass::IN;
+  size_t bufferSize = 0;
+  size_t queryLen = 0;
+  bool isTcp = false;
+  gettime(&queryRealTime, true);
+
+  DNSQuestion dq(providedName ? providedName : &qname, qtype, qclass, qname.wirelength(), &lc, &rem, &dh, bufferSize, queryLen, isTcp, &queryRealTime);
+  return dq;
+}
+
+static void benchPolicy(const ServerPolicy& pol)
+{
+#if BENCH_POLICIES
+  bool existingVerboseValue = g_verbose;
+  g_verbose = false;
+
+  std::vector<DNSName> names;
+  names.reserve(1000);
+  for (size_t idx = 0; idx < 1000; idx++) {
+    names.push_back(DNSName("powerdns-" + std::to_string(idx) + ".com."));
+  }
+  ServerPolicy::NumberedServerVector servers;
+  for (size_t idx = 1; idx <= 10; idx++) {
+    servers.push_back({ idx, std::make_shared<DownstreamState>(ComboAddress("192.0.2." + std::to_string(idx) + ":53")) });
+    servers.at(idx - 1).second->setUp();
+    /* we need to have a weight of at least 1000 to get an optimal repartition with the consistent hashing algo */
+    servers.at(idx - 1).second->setWeight(1000);
+    /* make sure that the hashes have been computed */
+    servers.at(idx - 1).second->hash();
+  }
+
+  StopWatch sw;
+  sw.start();
+  for (size_t idx = 0; idx < 1000; idx++) {
+  for (const auto& name : names) {
+    auto dq = getDQ(&name);
+    auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+  }
+  }
+  cerr<<pol.name<<" took "<<std::to_string(sw.udiff())<<" us for "<<names.size()<<endl;
+
+  g_verbose = existingVerboseValue;
+#endif /* BENCH_POLICIES */
+}
+
+static void resetLuaContext()
+{
+  /* we need to reset this before cleaning the Lua state because the server policy might holds
+     a reference to a Lua function (Lua policies) */
+  g_policy.setState(ServerPolicy("leastOutstanding", leastOutstanding, false));
+  g_lua = LuaContext();
+}
+
+BOOST_AUTO_TEST_SUITE(dnsdistlbpolicies)
+
+BOOST_AUTO_TEST_CASE(test_firstAvailable) {
+  auto dq = getDQ();
+
+  ServerPolicy pol{"firstAvailable", firstAvailable, false};
+  ServerPolicy::NumberedServerVector servers;
+  servers.push_back({ 1, std::make_shared<DownstreamState>(ComboAddress("192.0.2.1:53")) });
+
+  /* servers start as 'down' */
+  auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+  BOOST_CHECK(server == nullptr);
+
+  /* mark the server as 'up' */
+  servers.at(0).second->setUp();
+  server = getSelectedBackendFromPolicy(pol, servers, dq);
+  BOOST_CHECK(server != nullptr);
+
+  /* add a second server, we should still get the first one */
+  servers.push_back({ 2, std::make_shared<DownstreamState>(ComboAddress("192.0.2.2:53")) });
+  server = getSelectedBackendFromPolicy(pol, servers, dq);
+  BOOST_REQUIRE(server != nullptr);
+  BOOST_CHECK(server == servers.at(0).second);
+
+  /* mark the first server as 'down', second as 'up' */
+  servers.at(0).second->setDown();
+  servers.at(1).second->setUp();
+  server = getSelectedBackendFromPolicy(pol, servers, dq);
+  BOOST_REQUIRE(server != nullptr);
+  BOOST_CHECK(server == servers.at(1).second);
+
+  std::vector<DNSName> names;
+  names.reserve(1000);
+  for (size_t idx = 0; idx < 1000; idx++) {
+    names.push_back(DNSName("powerdns-" + std::to_string(idx) + ".com."));
+  }
+  std::map<std::shared_ptr<DownstreamState>, uint64_t> serversMap;
+  for (size_t idx = 1; idx <= 10; idx++) {
+    servers.push_back({ idx, std::make_shared<DownstreamState>(ComboAddress("192.0.2." + std::to_string(idx) + ":53")) });
+    serversMap[servers.at(idx - 1).second] = 0;
+    servers.at(idx - 1).second->setUp();
+  }
+
+  benchPolicy(pol);
+}
+
+BOOST_AUTO_TEST_CASE(test_leastOutstanding) {
+  auto dq = getDQ();
+
+  ServerPolicy pol{"leastOutstanding", leastOutstanding, false};
+  ServerPolicy::NumberedServerVector servers;
+  servers.push_back({ 1, std::make_shared<DownstreamState>(ComboAddress("192.0.2.1:53")) });
+
+  /* servers start as 'down' */
+  auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+  BOOST_CHECK(server == nullptr);
+
+  /* mark the server as 'up' */
+  servers.at(0).second->setUp();
+  server = getSelectedBackendFromPolicy(pol, servers, dq);
+  BOOST_CHECK(server != nullptr);
+
+  /* add a second server, we should still get the first one */
+  servers.push_back({ 2, std::make_shared<DownstreamState>(ComboAddress("192.0.2.2:53")) });
+  server = getSelectedBackendFromPolicy(pol, servers, dq);
+  BOOST_REQUIRE(server != nullptr);
+  BOOST_CHECK(server == servers.at(0).second);
+
+  /* mark the first server as 'down', second as 'up' */
+  servers.at(0).second->setDown();
+  servers.at(1).second->setUp();
+  server = getSelectedBackendFromPolicy(pol, servers, dq);
+  BOOST_REQUIRE(server != nullptr);
+  BOOST_CHECK(server == servers.at(1).second);
+
+  /* mark both servers as 'up', increase the outstanding count of the first one */
+  servers.at(0).second->setUp();
+  servers.at(0).second->outstanding = 42;
+  servers.at(1).second->setUp();
+  server = getSelectedBackendFromPolicy(pol, servers, dq);
+  BOOST_REQUIRE(server != nullptr);
+  BOOST_CHECK(server == servers.at(1).second);
+
+  benchPolicy(pol);
+}
+
+BOOST_AUTO_TEST_CASE(test_wrandom) {
+  auto dq = getDQ();
+
+  ServerPolicy pol{"wrandom", wrandom, false};
+  ServerPolicy::NumberedServerVector servers;
+  std::map<std::shared_ptr<DownstreamState>, uint64_t> serversMap;
+  for (size_t idx = 1; idx <= 10; idx++) {
+    servers.push_back({ idx, std::make_shared<DownstreamState>(ComboAddress("192.0.2." + std::to_string(idx) + ":53")) });
+    serversMap[servers.at(idx - 1).second] = 0;
+    servers.at(idx - 1).second->setUp();
+  }
+
+  benchPolicy(pol);
+
+  for (size_t idx = 0; idx < 1000; idx++) {
+    auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+    BOOST_REQUIRE(serversMap.count(server) == 1);
+    ++serversMap[server];
+  }
+  uint64_t total = 0;
+  for (const auto& entry : serversMap) {
+    BOOST_CHECK_GT(entry.second, 0);
+    BOOST_CHECK_GT(entry.second, (1000 / servers.size() / 2));
+    BOOST_CHECK_LT(entry.second, (1000 / servers.size() * 2));
+    total += entry.second;
+  }
+  BOOST_CHECK_EQUAL(total, 1000);
+
+  /* reset */
+  for (auto& entry : serversMap) {
+    entry.second = 0;
+    BOOST_CHECK_EQUAL(entry.first->weight, 1);
+  }
+
+  /* reset */
+  for (auto& entry : serversMap) {
+    entry.second = 0;
+    BOOST_CHECK_EQUAL(entry.first->weight, 1);
+  }
+  /* change the weight of the last server to 100, default is 1 */
+  servers.at(servers.size()-1).second->weight = 100;
+
+  for (size_t idx = 0; idx < 1000; idx++) {
+    auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+    BOOST_REQUIRE(serversMap.count(server) == 1);
+    ++serversMap[server];
+  }
+
+  total = 0;
+  uint64_t totalW = 0;
+  for (const auto& entry : serversMap) {
+    total += entry.second;
+    totalW += entry.first->weight;
+  }
+  BOOST_CHECK_EQUAL(total, 1000);
+  auto last = servers.at(servers.size()-1).second;
+  const auto got = serversMap[last];
+  float expected = (1000 * 1.0 * last->weight) / totalW;
+  BOOST_CHECK_GT(got, expected / 2);
+  BOOST_CHECK_LT(got, expected * 2);
+}
+
+BOOST_AUTO_TEST_CASE(test_whashed) {
+  std::vector<DNSName> names;
+  names.reserve(1000);
+  for (size_t idx = 0; idx < 1000; idx++) {
+    names.push_back(DNSName("powerdns-" + std::to_string(idx) + ".com."));
+  }
+
+  ServerPolicy pol{"whashed", whashed, false};
+  ServerPolicy::NumberedServerVector servers;
+  std::map<std::shared_ptr<DownstreamState>, uint64_t> serversMap;
+  for (size_t idx = 1; idx <= 10; idx++) {
+    servers.push_back({ idx, std::make_shared<DownstreamState>(ComboAddress("192.0.2." + std::to_string(idx) + ":53")) });
+    serversMap[servers.at(idx - 1).second] = 0;
+    servers.at(idx - 1).second->setUp();
+  }
+
+  benchPolicy(pol);
+
+  for (const auto& name : names) {
+    auto dq = getDQ(&name);
+    auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+    BOOST_REQUIRE(serversMap.count(server) == 1);
+    ++serversMap[server];
+  }
+
+  uint64_t total = 0;
+  for (const auto& entry : serversMap) {
+    BOOST_CHECK_GT(entry.second, 0);
+    BOOST_CHECK_GT(entry.second, (names.size() / servers.size() / 2));
+    BOOST_CHECK_LT(entry.second, (names.size() / servers.size() * 2));
+    total += entry.second;
+  }
+  BOOST_CHECK_EQUAL(total, names.size());
+
+  /* reset */
+  for (auto& entry : serversMap) {
+    entry.second = 0;
+    BOOST_CHECK_EQUAL(entry.first->weight, 1);
+  }
+
+  /* request 1000 times the same name, we should go to the same server every time */
+  {
+    auto dq = getDQ(&names.at(0));
+    auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+    for (size_t idx = 0; idx < 1000; idx++) {
+      BOOST_CHECK(getSelectedBackendFromPolicy(pol, servers, dq) == server);
+    }
+  }
+
+  /* reset */
+  for (auto& entry : serversMap) {
+    entry.second = 0;
+    BOOST_CHECK_EQUAL(entry.first->weight, 1);
+  }
+  /* change the weight of the last server to 100, default is 1 */
+  servers.at(servers.size()-1).second->setWeight(100);
+
+  for (const auto& name : names) {
+    auto dq = getDQ(&name);
+    auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+    BOOST_REQUIRE(serversMap.count(server) == 1);
+    ++serversMap[server];
+  }
+
+  total = 0;
+  uint64_t totalW = 0;
+  for (const auto& entry : serversMap) {
+    total += entry.second;
+    totalW += entry.first->weight;
+  }
+  BOOST_CHECK_EQUAL(total, names.size());
+  auto last = servers.at(servers.size()-1).second;
+  const auto got = serversMap[last];
+  float expected = (names.size() * 1.0 * last->weight) / totalW;
+  BOOST_CHECK_GT(got, expected / 2);
+  BOOST_CHECK_LT(got, expected * 2);
+}
+
+BOOST_AUTO_TEST_CASE(test_chashed) {
+  bool existingVerboseValue = g_verbose;
+  g_verbose = false;
+
+  std::vector<DNSName> names;
+  names.reserve(1000);
+  for (size_t idx = 0; idx < 1000; idx++) {
+    names.push_back(DNSName("powerdns-" + std::to_string(idx) + ".com."));
+  }
+
+  ServerPolicy pol{"chashed", chashed, false};
+  ServerPolicy::NumberedServerVector servers;
+  std::map<std::shared_ptr<DownstreamState>, uint64_t> serversMap;
+  for (size_t idx = 1; idx <= 10; idx++) {
+    servers.push_back({ idx, std::make_shared<DownstreamState>(ComboAddress("192.0.2." + std::to_string(idx) + ":53")) });
+    serversMap[servers.at(idx - 1).second] = 0;
+    servers.at(idx - 1).second->setUp();
+    /* we need to have a weight of at least 1000 to get an optimal repartition with the consistent hashing algo */
+    servers.at(idx - 1).second->setWeight(1000);
+    /* make sure that the hashes have been computed */
+    servers.at(idx - 1).second->hash();
+  }
+
+  benchPolicy(pol);
+
+  for (const auto& name : names) {
+    auto dq = getDQ(&name);
+    auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+    BOOST_REQUIRE(serversMap.count(server) == 1);
+    ++serversMap[server];
+  }
+
+  uint64_t total = 0;
+  for (const auto& entry : serversMap) {
+    BOOST_CHECK_GT(entry.second, 0);
+    BOOST_CHECK_GT(entry.second, (names.size() / servers.size() / 2));
+    BOOST_CHECK_LT(entry.second, (names.size() / servers.size() * 2));
+    total += entry.second;
+  }
+  BOOST_CHECK_EQUAL(total, names.size());
+
+  /* reset */
+  for (auto& entry : serversMap) {
+    entry.second = 0;
+    BOOST_CHECK_EQUAL(entry.first->weight, 1000);
+  }
+
+  /* request 1000 times the same name, we should go to the same server every time */
+  {
+    auto dq = getDQ(&names.at(0));
+    auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+    for (size_t idx = 0; idx < 1000; idx++) {
+      BOOST_CHECK(getSelectedBackendFromPolicy(pol, servers, dq) == server);
+    }
+  }
+
+  /* reset */
+  for (auto& entry : serversMap) {
+    entry.second = 0;
+    BOOST_CHECK_EQUAL(entry.first->weight, 1000);
+  }
+  /* change the weight of the last server to 100000, others stay at 1000 */
+  servers.at(servers.size()-1).second->setWeight(100000);
+
+  for (const auto& name : names) {
+    auto dq = getDQ(&name);
+    auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+    BOOST_REQUIRE(serversMap.count(server) == 1);
+    ++serversMap[server];
+  }
+
+  total = 0;
+  uint64_t totalW = 0;
+  for (const auto& entry : serversMap) {
+    total += entry.second;
+    totalW += entry.first->weight;
+  }
+  BOOST_CHECK_EQUAL(total, names.size());
+  auto last = servers.at(servers.size()-1).second;
+  const auto got = serversMap[last];
+  float expected = (names.size() * 1.0 * last->weight) / totalW;
+  BOOST_CHECK_GT(got, expected / 2);
+  BOOST_CHECK_LT(got, expected * 2);
+
+  g_verbose = existingVerboseValue;
+}
+
+BOOST_AUTO_TEST_CASE(test_lua) {
+  std::vector<DNSName> names;
+  names.reserve(1000);
+  for (size_t idx = 0; idx < 1000; idx++) {
+    names.push_back(DNSName("powerdns-" + std::to_string(idx) + ".com."));
+  }
+
+  static const std::string policySetupStr = R"foo(
+    local counter = 0
+    function luaroundrobin(servers, dq)
+      counter = counter + 1
+      return servers[1 + (counter % #servers)]
+    end
+
+    setServerPolicyLua("luaroundrobin", luaroundrobin)
+  )foo";
+  resetLuaContext();
+  g_lua.writeFunction("setServerPolicyLua", [](string name, ServerPolicy::policyfunc_t policy) {
+      g_policy.setState(ServerPolicy{name, policy, true});
+    });
+  g_lua.executeCode(policySetupStr);
+
+  ServerPolicy pol = g_policy.getCopy();
+  ServerPolicy::NumberedServerVector servers;
+  std::map<std::shared_ptr<DownstreamState>, uint64_t> serversMap;
+  for (size_t idx = 1; idx <= 10; idx++) {
+    servers.push_back({ idx, std::make_shared<DownstreamState>(ComboAddress("192.0.2." + std::to_string(idx) + ":53")) });
+    serversMap[servers.at(idx - 1).second] = 0;
+    servers.at(idx - 1).second->setUp();
+  }
+  BOOST_REQUIRE_EQUAL(servers.size(), 10);
+
+  for (const auto& name : names) {
+    auto dq = getDQ(&name);
+    auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+    BOOST_REQUIRE(serversMap.count(server) == 1);
+    ++serversMap[server];
+  }
+
+  uint64_t total = 0;
+  for (const auto& entry : serversMap) {
+    BOOST_CHECK_GT(entry.second, 0);
+    BOOST_CHECK_GT(entry.second, (names.size() / servers.size() / 2));
+    BOOST_CHECK_LT(entry.second, (names.size() / servers.size() * 2));
+    total += entry.second;
+  }
+  BOOST_CHECK_EQUAL(total, names.size());
+
+  benchPolicy(pol);
+}
+
+#ifdef LUAJIT_VERSION
+
+BOOST_AUTO_TEST_CASE(test_lua_ffi_rr) {
+  std::vector<DNSName> names;
+  names.reserve(1000);
+  for (size_t idx = 0; idx < 1000; idx++) {
+    names.push_back(DNSName("powerdns-" + std::to_string(idx) + ".com."));
+  }
+
+  static const std::string policySetupStr = R"foo(
+    local ffi = require("ffi")
+    local C = ffi.C
+    local counter = 0
+    function ffilb(servers_list, dq)
+      local serversCount = tonumber(C.dnsdist_ffi_servers_list_get_count(servers_list))
+      counter = counter + 1
+      return counter % serversCount
+    end
+
+    setServerPolicyLuaFFI("FFI round-robin", ffilb)
+  )foo";
+  resetLuaContext();
+  g_lua.executeCode(getLuaFFIWrappers());
+  g_lua.writeFunction("setServerPolicyLuaFFI", [](string name, ServerPolicy::ffipolicyfunc_t policy) {
+      g_policy.setState(ServerPolicy(name, policy));
+    });
+  g_lua.executeCode(policySetupStr);
+
+  ServerPolicy pol = g_policy.getCopy();
+  ServerPolicy::NumberedServerVector servers;
+  std::map<std::shared_ptr<DownstreamState>, uint64_t> serversMap;
+  for (size_t idx = 1; idx <= 10; idx++) {
+    servers.push_back({ idx, std::make_shared<DownstreamState>(ComboAddress("192.0.2." + std::to_string(idx) + ":53")) });
+    serversMap[servers.at(idx - 1).second] = 0;
+    servers.at(idx - 1).second->setUp();
+  }
+  BOOST_REQUIRE_EQUAL(servers.size(), 10);
+
+  for (const auto& name : names) {
+    auto dq = getDQ(&name);
+    auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+    BOOST_REQUIRE(serversMap.count(server) == 1);
+    ++serversMap[server];
+  }
+
+  uint64_t total = 0;
+  for (const auto& entry : serversMap) {
+    BOOST_CHECK_GT(entry.second, 0);
+    BOOST_CHECK_GT(entry.second, (names.size() / servers.size() / 2));
+    BOOST_CHECK_LT(entry.second, (names.size() / servers.size() * 2));
+    total += entry.second;
+  }
+  BOOST_CHECK_EQUAL(total, names.size());
+
+  benchPolicy(pol);
+}
+
+BOOST_AUTO_TEST_CASE(test_lua_ffi_hashed) {
+  std::vector<DNSName> names;
+  names.reserve(1000);
+  for (size_t idx = 0; idx < 1000; idx++) {
+    names.push_back(DNSName("powerdns-" + std::to_string(idx) + ".com."));
+  }
+
+  static const std::string policySetupStr = R"foo(
+    local ffi = require("ffi")
+    local C = ffi.C
+    function ffilb(servers_list, dq)
+      local serversCount = tonumber(C.dnsdist_ffi_servers_list_get_count(servers_list))
+      local hash = tonumber(C.dnsdist_ffi_dnsquestion_get_qname_hash(dq, 0))
+      return hash % serversCount
+    end
+
+    setServerPolicyLuaFFI("FFI hashed", ffilb)
+  )foo";
+  resetLuaContext();
+  g_lua.executeCode(getLuaFFIWrappers());
+  g_lua.writeFunction("setServerPolicyLuaFFI", [](string name, ServerPolicy::ffipolicyfunc_t policy) {
+      g_policy.setState(ServerPolicy(name, policy));
+    });
+  g_lua.executeCode(policySetupStr);
+
+  ServerPolicy pol = g_policy.getCopy();
+  ServerPolicy::NumberedServerVector servers;
+  std::map<std::shared_ptr<DownstreamState>, uint64_t> serversMap;
+  for (size_t idx = 1; idx <= 10; idx++) {
+    servers.push_back({ idx, std::make_shared<DownstreamState>(ComboAddress("192.0.2." + std::to_string(idx) + ":53")) });
+    serversMap[servers.at(idx - 1).second] = 0;
+    servers.at(idx - 1).second->setUp();
+  }
+  BOOST_REQUIRE_EQUAL(servers.size(), 10);
+
+  for (const auto& name : names) {
+    auto dq = getDQ(&name);
+    auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+    BOOST_REQUIRE(serversMap.count(server) == 1);
+    ++serversMap[server];
+  }
+
+  uint64_t total = 0;
+  for (const auto& entry : serversMap) {
+    BOOST_CHECK_GT(entry.second, 0);
+    BOOST_CHECK_GT(entry.second, (names.size() / servers.size() / 2));
+    BOOST_CHECK_LT(entry.second, (names.size() / servers.size() * 2));
+    total += entry.second;
+  }
+  BOOST_CHECK_EQUAL(total, names.size());
+
+  benchPolicy(pol);
+}
+
+BOOST_AUTO_TEST_CASE(test_lua_ffi_whashed) {
+  std::vector<DNSName> names;
+  names.reserve(1000);
+  for (size_t idx = 0; idx < 1000; idx++) {
+    names.push_back(DNSName("powerdns-" + std::to_string(idx) + ".com."));
+  }
+
+  static const std::string policySetupStr = R"foo(
+    local ffi = require("ffi")
+    local C = ffi.C
+    function ffilb(servers_list, dq)
+      return tonumber(C.dnsdist_ffi_servers_list_whashed(servers_list, dq, C.dnsdist_ffi_dnsquestion_get_qname_hash(dq, 0)))
+    end
+
+    setServerPolicyLuaFFI("FFI whashed", ffilb)
+  )foo";
+  resetLuaContext();
+  g_lua.executeCode(getLuaFFIWrappers());
+  g_lua.writeFunction("setServerPolicyLuaFFI", [](string name, ServerPolicy::ffipolicyfunc_t policy) {
+      g_policy.setState(ServerPolicy(name, policy));
+    });
+  g_lua.executeCode(policySetupStr);
+
+  ServerPolicy pol = g_policy.getCopy();
+  ServerPolicy::NumberedServerVector servers;
+  std::map<std::shared_ptr<DownstreamState>, uint64_t> serversMap;
+  for (size_t idx = 1; idx <= 10; idx++) {
+    servers.push_back({ idx, std::make_shared<DownstreamState>(ComboAddress("192.0.2." + std::to_string(idx) + ":53")) });
+    serversMap[servers.at(idx - 1).second] = 0;
+    servers.at(idx - 1).second->setUp();
+  }
+  BOOST_REQUIRE_EQUAL(servers.size(), 10);
+
+  for (const auto& name : names) {
+    auto dq = getDQ(&name);
+    auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+    BOOST_REQUIRE(serversMap.count(server) == 1);
+    ++serversMap[server];
+  }
+
+  uint64_t total = 0;
+  for (const auto& entry : serversMap) {
+    BOOST_CHECK_GT(entry.second, 0);
+    BOOST_CHECK_GT(entry.second, (names.size() / servers.size() / 2));
+    BOOST_CHECK_LT(entry.second, (names.size() / servers.size() * 2));
+    total += entry.second;
+  }
+  BOOST_CHECK_EQUAL(total, names.size());
+
+  benchPolicy(pol);
+}
+
+BOOST_AUTO_TEST_CASE(test_lua_ffi_chashed) {
+  bool existingVerboseValue = g_verbose;
+  g_verbose = false;
+
+  std::vector<DNSName> names;
+  names.reserve(1000);
+  for (size_t idx = 0; idx < 1000; idx++) {
+    names.push_back(DNSName("powerdns-" + std::to_string(idx) + ".com."));
+  }
+
+  static const std::string policySetupStr = R"foo(
+    local ffi = require("ffi")
+    local C = ffi.C
+    function ffilb(servers_list, dq)
+      return tonumber(C.dnsdist_ffi_servers_list_chashed(servers_list, dq, C.dnsdist_ffi_dnsquestion_get_qname_hash(dq, 0)))
+    end
+
+    setServerPolicyLuaFFI("FFI chashed", ffilb)
+  )foo";
+  resetLuaContext();
+  g_lua.executeCode(getLuaFFIWrappers());
+  g_lua.writeFunction("setServerPolicyLuaFFI", [](string name, ServerPolicy::ffipolicyfunc_t policy) {
+      g_policy.setState(ServerPolicy(name, policy));
+    });
+  g_lua.executeCode(policySetupStr);
+
+  ServerPolicy pol = g_policy.getCopy();
+  ServerPolicy::NumberedServerVector servers;
+  std::map<std::shared_ptr<DownstreamState>, uint64_t> serversMap;
+  for (size_t idx = 1; idx <= 10; idx++) {
+    servers.push_back({ idx, std::make_shared<DownstreamState>(ComboAddress("192.0.2." + std::to_string(idx) + ":53")) });
+    serversMap[servers.at(idx - 1).second] = 0;
+    servers.at(idx - 1).second->setUp();
+    /* we need to have a weight of at least 1000 to get an optimal repartition with the consistent hashing algo */
+    servers.at(idx - 1).second->setWeight(1000);
+    /* make sure that the hashes have been computed */
+    servers.at(idx - 1).second->hash();
+  }
+  BOOST_REQUIRE_EQUAL(servers.size(), 10);
+
+  for (const auto& name : names) {
+    auto dq = getDQ(&name);
+    auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+    BOOST_REQUIRE(serversMap.count(server) == 1);
+    ++serversMap[server];
+  }
+
+  uint64_t total = 0;
+  for (const auto& entry : serversMap) {
+    BOOST_CHECK_GT(entry.second, 0);
+    BOOST_CHECK_GT(entry.second, (names.size() / servers.size() / 2));
+    BOOST_CHECK_LT(entry.second, (names.size() / servers.size() * 2));
+    total += entry.second;
+  }
+  BOOST_CHECK_EQUAL(total, names.size());
+
+  benchPolicy(pol);
+
+  g_verbose = existingVerboseValue;
+}
+
+#endif /* LUAJIT_VERSION */
+
+BOOST_AUTO_TEST_SUITE_END()
index ba1a67ac09a17ee83145236c3f30e15ce8afc775..e1b9f21aa1062a195dce2c15999cb567a1510972 100644 (file)
@@ -1892,3 +1892,193 @@ class TestAdvancedSetNegativeAndSOA(DNSDistTest):
             sender = getattr(self, method)
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertEquals(receivedResponse, expectedResponse)
+
+class TestAdvancedLuaRule(DNSDistTest):
+
+    _config_template = """
+
+    function luarulefunction(dq)
+      if dq.qname:toString() == 'lua-rule.advanced.tests.powerdns.com.' then
+        return true
+      end
+      return false
+    end
+
+    addAction(LuaRule(luarulefunction), RCodeAction(DNSRCode.NOTIMP))
+    addAction(AllRule(), RCodeAction(DNSRCode.REFUSED))
+    -- newServer{address="127.0.0.1:%s"}
+    """
+
+    def testAdvancedLuaRule(self):
+        """
+        Advanced: Test the LuaRule rule
+        """
+        name = 'lua-rule.advanced.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        notimplResponse = dns.message.make_response(query)
+        notimplResponse.set_rcode(dns.rcode.NOTIMP)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, notimplResponse)
+
+        name = 'not-lua-rule.advanced.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        refusedResponse = dns.message.make_response(query)
+        refusedResponse.set_rcode(dns.rcode.REFUSED)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, refusedResponse)
+
+class TestAdvancedLuaFFI(DNSDistTest):
+
+    _config_template = """
+    local ffi = require("ffi")
+
+    local expectingUDP = true
+
+    function luaffirulefunction(dq)
+      local qtype = ffi.C.dnsdist_ffi_dnsquestion_get_qtype(dq)
+      if qtype ~= DNSQType.A and qtype ~= DNSQType.SOA then
+        print('invalid qtype')
+        return false
+      end
+
+      local qclass = ffi.C.dnsdist_ffi_dnsquestion_get_qclass(dq)
+      if qclass ~= DNSClass.IN then
+        print('invalid qclass')
+        return false
+      end
+
+      local ret_ptr = ffi.new("char *[1]")
+      local ret_ptr_param = ffi.cast("const char **", ret_ptr)
+      local ret_size = ffi.new("size_t[1]")
+      local ret_size_param = ffi.cast("size_t*", ret_size)
+      ffi.C.dnsdist_ffi_dnsquestion_get_qname_raw(dq, ret_ptr_param, ret_size_param)
+      if ret_size[0] ~= 36 then
+        print('invalid length for the qname ')
+        print(ret_size[0])
+        return false
+      end
+
+      local expectedQname = string.char(6)..'luaffi'..string.char(8)..'advanced'..string.char(5)..'tests'..string.char(8)..'powerdns'..string.char(3)..'com'
+      if ffi.string(ret_ptr[0]) ~= expectedQname then
+        print('invalid qname')
+        print(ffi.string(ret_ptr[0]))
+        return false
+      end
+
+      local rcode = ffi.C.dnsdist_ffi_dnsquestion_get_rcode(dq)
+      if rcode ~= 0 then
+        print('invalid rcode')
+        return false
+      end
+
+      local opcode = ffi.C.dnsdist_ffi_dnsquestion_get_opcode(dq)
+      if qtype == DNSQType.A and opcode ~= DNSOpcode.Query then
+        print('invalid opcode')
+        return false
+      elseif qtype == DNSQType.SOA and opcode ~= DNSOpcode.Update then
+        print('invalid opcode')
+        return false
+      end
+
+      local tcp = ffi.C.dnsdist_ffi_dnsquestion_get_tcp(dq)
+      if expectingUDP == tcp then
+        print('invalid tcp')
+        return false
+      end
+      expectingUDP = expectingUDP == false
+
+      local dnssecok = ffi.C.dnsdist_ffi_dnsquestion_get_do(dq)
+      if dnssecok ~= false then
+        print('invalid DNSSEC OK')
+        return false
+      end
+
+      local len = ffi.C.dnsdist_ffi_dnsquestion_get_len(dq)
+      if len ~= 52 then
+        print('invalid length')
+        print(len)
+        return false
+      end
+
+      local tag = ffi.C.dnsdist_ffi_dnsquestion_get_tag(dq, 'a-tag')
+      if ffi.string(tag) ~= 'a-value' then
+        print('invalid tag value')
+        print(ffi.string(tag))
+        return false
+      end
+      return true
+    end
+
+    function luaffiactionfunction(dq)
+      local qtype = ffi.C.dnsdist_ffi_dnsquestion_get_qtype(dq)
+      if qtype == DNSQType.A then
+        local str = "192.0.2.1"
+        local buf = ffi.new("char[?]", #str + 1)
+        ffi.copy(buf, str)
+        ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+        return DNSAction.Spoof
+      elseif qtype == DNSQType.SOA then
+        ffi.C.dnsdist_ffi_dnsquestion_set_rcode(dq, DNSRCode.REFUSED)
+        return DNSAction.Refused
+      end
+    end
+
+    function luaffiactionsettag(dq)
+      ffi.C.dnsdist_ffi_dnsquestion_set_tag(dq, 'a-tag', 'a-value')
+      return DNSAction.None
+    end
+
+    addAction(AllRule(), LuaFFIAction(luaffiactionsettag))
+    addAction(LuaFFIRule(luaffirulefunction), LuaFFIAction(luaffiactionfunction))
+    -- newServer{address="127.0.0.1:%s"}
+    """
+
+    def testAdvancedLuaFFI(self):
+        """
+        Advanced: Test the Lua FFI interface
+        """
+        name = 'luaffi.advanced.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, response)
+
+    def testAdvancedLuaFFIUpdate(self):
+        """
+        Advanced: Test the Lua FFI interface via an update
+        """
+        name = 'luaffi.advanced.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'SOA', 'IN')
+        query.set_opcode(dns.opcode.UPDATE)
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+
+        response = dns.message.make_response(query)
+        response.set_rcode(dns.rcode.REFUSED)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, response)
index d934421a02495037fffcc87584f70659e546079b..ec60470bdaeccf4f33bec16561de30b48a2602e2 100644 (file)
@@ -10,8 +10,6 @@ from dnsdisttests import DNSDistTest
 
 import pycurl
 from io import BytesIO
-#from hyper import HTTP20Connection
-#from hyper.ssl_compat import SSLContext, PROTOCOL_TLSv1_2
 
 @unittest.skipIf('SKIP_DOH_TESTS' in os.environ, 'DNS over HTTPS tests are disabled')
 class DNSDistDOHTest(DNSDistTest):
@@ -140,36 +138,6 @@ class DNSDistDOHTest(DNSDistTest):
 
         print("Launching tests..")
 
-#     @classmethod
-#     def openDOHConnection(cls, port, caFile, timeout=2.0):
-#         sslctx = SSLContext(PROTOCOL_TLSv1_2)
-#         sslctx.load_verify_locations(caFile)
-#         return HTTP20Connection('127.0.0.1', port=port, secure=True, timeout=timeout, ssl_context=sslctx, force_proto='h2')
-
-#     @classmethod
-#     def sendDOHQueryOverConnection(cls, conn, baseurl, query, response=None, timeout=2.0):
-#         url = cls.getDOHGetURL(baseurl, query)
-
-#         if response:
-#             cls._toResponderQueue.put(response, True, timeout)
-
-#         conn.request('GET', url)
-
-#     @classmethod
-#     def recvDOHResponseOverConnection(cls, conn, useQueue=False, timeout=2.0):
-#         message = None
-#         data = conn.get_response()
-#         if data:
-#             data = data.read()
-#             if data:
-#                 message = dns.message.from_wire(data)
-
-#         if useQueue and not cls._fromResponderQueue.empty():
-#             receivedQuery = cls._fromResponderQueue.get(True, timeout)
-#             return (receivedQuery, message)
-#         else:
-#             return message
-
 class TestDOH(DNSDistDOHTest):
 
     _serverKey = 'server.key'
@@ -920,3 +888,65 @@ class TestDOHWithoutCacheControl(DNSDistDOHTest):
         self.checkNoHeader('cache-control')
         self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
         self.assertEquals(response, receivedResponse)
+
+class TestDOHFFI(DNSDistDOHTest):
+
+    _serverKey = 'server.key'
+    _serverCert = 'server.chain'
+    _serverName = 'tls.tests.dnsdist.org'
+    _caCert = 'ca.pem'
+    _dohServerPort = 8443
+    _customResponseHeader1 = 'access-control-allow-origin: *'
+    _customResponseHeader2 = 'user-agent: derp'
+    _dohBaseURL = ("https://%s:%d/" % (_serverName, _dohServerPort))
+    _config_template = """
+    newServer{address="127.0.0.1:%s"}
+
+    addDOHLocal("127.0.0.1:%s", "%s", "%s", { "/" }, {customResponseHeaders={["access-control-allow-origin"]="*",["user-agent"]="derp",["UPPERCASE"]="VaLuE"}})
+
+    local ffi = require("ffi")
+
+    function dohHandler(dq)
+      local scheme = ffi.string(ffi.C.dnsdist_ffi_dnsquestion_get_http_scheme(dq))
+      local host = ffi.string(ffi.C.dnsdist_ffi_dnsquestion_get_http_host(dq))
+      local path = ffi.string(ffi.C.dnsdist_ffi_dnsquestion_get_http_path(dq))
+      local query_string = ffi.string(ffi.C.dnsdist_ffi_dnsquestion_get_http_query_string(dq))
+      if scheme == 'https' and host == '%s:%d' and path == '/' and query_string == '' then
+        local foundct = false
+        local headers_ptr = ffi.new("const dnsdist_ffi_http_header_t *[1]")
+        local headers_ptr_param = ffi.cast("const dnsdist_ffi_http_header_t **", headers_ptr)
+
+        local headers_count = tonumber(ffi.C.dnsdist_ffi_dnsquestion_get_http_headers(dq, headers_ptr_param))
+        if headers_count > 0 then
+          for idx = 0, headers_count-1 do
+            if ffi.string(headers_ptr[0][idx].name) == 'content-type' and ffi.string(headers_ptr[0][idx].value) == 'application/dns-message' then
+              foundct = true
+              break
+            end
+          end
+        end
+        if foundct then
+          ffi.C.dnsdist_ffi_dnsquestion_set_http_response(dq, 200, 'It works!', 'text/plain')
+          return DNSAction.HeaderModify
+        end
+      end
+      return DNSAction.None
+    end
+    addAction("http-lua-ffi.doh.tests.powerdns.com.", LuaFFIAction(dohHandler))
+    """
+    _config_params = ['_testServerPort', '_dohServerPort', '_serverCert', '_serverKey', '_serverName', '_dohServerPort']
+
+    def testHTTPLuaFFIResponse(self):
+        """
+        DOH: Lua FFI HTTP Response
+        """
+        name = 'http-lua-ffi.doh.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        query.id = 0
+
+        (_, receivedResponse) = self.sendDOHPostQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, caFile=self._caCert, useQueue=False, rawResponse=True)
+        self.assertTrue(receivedResponse)
+        self.assertEquals(receivedResponse, b'It works!')
+        self.assertEquals(self._rcode, 200)
+        self.assertTrue('content-type: text/plain' in self._response_headers.decode())
+
index 0315ed589dbc4ac32e60ea32ec2ae8bf1e28e8a0..f1e386e82546fc40a81ab819f370e54903f972e4 100644 (file)
@@ -384,3 +384,288 @@ class TestEDNSOptionsAddingECS(EDNSOptionsBase):
             receivedQuery.id = query.id
             self.assertEquals(receivedQuery, query)
             self.assertEquals(receivedResponse, response)
+
+class TestEDNSOptionsLuaFFI(DNSDistTest):
+
+    _config_template = """
+    local ffi = require("ffi")
+
+    function testEDNSOptions(dq)
+      local options_ptr = ffi.new("const dnsdist_ffi_ednsoption_t *[1]")
+      local ret_ptr_param = ffi.cast("const dnsdist_ffi_ednsoption_t **", options_ptr)
+
+      local options_count = tonumber(ffi.C.dnsdist_ffi_dnsquestion_get_edns_options(dq, ret_ptr_param))
+
+      local qname_ptr = ffi.new("const char *[1]")
+      local qname_ptr_param = ffi.cast("const char **", qname_ptr)
+      local qname_size = ffi.new("size_t[1]")
+      local qname_size_param = ffi.cast("size_t*", qname_size)
+      ffi.C.dnsdist_ffi_dnsquestion_get_qname_raw(dq, qname_ptr_param, qname_size_param)
+      local qname = ffi.string(qname_ptr[0])
+
+      if string.match(qname, 'noedns') then
+        if options_count ~= 0 then
+          local str = "192.0.2.255"
+          local buf = ffi.new("char[?]", #str + 1)
+          ffi.copy(buf, str)
+          ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+          return DNSAction.Spoof
+        end
+      end
+
+      local cookies_count = 0
+      local ecs_count = 0
+      local ecs_index = -1
+      local first_cookie_index = -1
+      local last_cookie_index = -1
+      if options_count > 0 then
+        for idx = 0, options_count-1 do
+          if options_ptr[0][idx].optionCode == EDNSOptionCode.COOKIE then
+            cookies_count = cookies_count + 1
+            if first_cookie_index == -1 then
+              first_cookie_index = idx
+            end
+            last_cookie_index = idx
+          elseif options_ptr[0][idx].optionCode == EDNSOptionCode.ECS then
+            ecs_count = ecs_count + 1
+            ecs_index = idx
+          end
+        end
+      end
+
+      if string.match(qname, 'multiplecookies') then
+        if cookies_count == 0 then
+          local str = "192.0.2.1"
+          local buf = ffi.new("char[?]", #str + 1)
+          ffi.copy(buf, str)
+          ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+          return DNSAction.Spoof
+        end
+        if cookies_count ~= 2 then
+          local str = "192.0.2.2"
+          local buf = ffi.new("char[?]", #str + 1)
+          ffi.copy(buf, str)
+          ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+          return DNSAction.Spoof
+        end
+        if options_ptr[0][first_cookie_index].len ~= 16 then
+          local str = "192.0.2.3"
+          local buf = ffi.new("char[?]", #str + 1)
+          ffi.copy(buf, str)
+          ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+          return DNSAction.Spoof
+        end
+        if options_ptr[0][last_cookie_index].len ~= 16 then
+          local str = "192.0.2.4"
+          local buf = ffi.new("char[?]", #str + 1)
+          ffi.copy(buf, str)
+          ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+          return DNSAction.Spoof
+        end
+      elseif string.match(qname, 'cookie') then
+
+        if cookies_count == 0 then
+          local str = "192.0.2.1"
+          local buf = ffi.new("char[?]", #str + 1)
+          ffi.copy(buf, str)
+          ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+          return DNSAction.Spoof
+        end
+        if cookies_count ~= 1 or options_ptr[0][first_cookie_index].len ~= 16 then
+          local str = "192.0.2.2"
+          local buf = ffi.new("char[?]", #str + 1)
+          ffi.copy(buf, str)
+          ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+          return DNSAction.Spoof
+        end
+      end
+
+      if string.match(qname, 'ecs4') then
+        if ecs_count == 0 then
+          local str = "192.0.2.51"
+          local buf = ffi.new("char[?]", #str + 1)
+          ffi.copy(buf, str)
+          ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+          return DNSAction.Spoof
+        end
+
+        if ecs_count ~= 1 or options_ptr[0][ecs_index].len ~= 8 then
+          local str = "192.0.2.52"
+          local buf = ffi.new("char[?]", #str + 1)
+          ffi.copy(buf, str)
+          ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+          return DNSAction.Spoof
+        end
+      end
+
+      if string.match(qname, 'ecs6') then
+        if ecs_count == 0 then
+          local str = "192.0.2.101"
+          local buf = ffi.new("char[?]", #str + 1)
+          ffi.copy(buf, str)
+          ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+          return DNSAction.Spoof
+        end
+        if ecs_count ~= 1 or options_ptr[0][ecs_index].len ~= 20 then
+          local str = "192.0.2.102"
+          local buf = ffi.new("char[?]", #str + 1)
+          ffi.copy(buf, str)
+          ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+          return DNSAction.Spoof
+        end
+      end
+
+      return DNSAction.None
+
+    end
+
+    addAction(AllRule(), LuaFFIAction(testEDNSOptions))
+
+    newServer{address="127.0.0.1:%s"}
+    """
+
+    def testWithoutEDNSFFI(self):
+        """
+        EDNS Options: No EDNS (FFI)
+        """
+        name = 'noedns.ednsoptions.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.255')
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
+
+    def testCookieFFI(self):
+        """
+        EDNS Options: Cookie (FFI)
+        """
+        name = 'cookie.ednsoptions.tests.powerdns.com.'
+        eco = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco])
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
+
+    def testECS4FFI(self):
+        """
+        EDNS Options: ECS4 (FFI)
+        """
+        name = 'ecs4.ednsoptions.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('1.2.3.4', 32)
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso])
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
+
+    def testECS6FFI(self):
+        """
+        EDNS Options: ECS6 (FFI)
+        """
+        name = 'ecs6.ednsoptions.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128)
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso])
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
+
+    def testECS6CookieFFI(self):
+        """
+        EDNS Options: Cookie + ECS6 (FFI)
+        """
+        name = 'cookie-ecs6.ednsoptions.tests.powerdns.com.'
+        eco = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
+        ecso = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128)
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso,eco])
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
+
+    def testMultiCookiesECS6FFI(self):
+        """
+        EDNS Options: Two Cookies + ECS6 (FFI)
+        """
+        name = 'multiplecookies-ecs6.ednsoptions.tests.powerdns.com.'
+        eco1 = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
+        ecso = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128)
+        eco2 = cookiesoption.CookiesOption(b'deadc0de', b'deadc0de')
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco1, ecso, eco2])
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)