}
auto states = g_dstates.getLocal();
for(const auto& state : *states) {
- string serverName = state->name.empty() ? (state->remote.toString() + ":" + std::to_string(state->remote.getPort())) : state->getName();
+ string serverName = state->getName().empty() ? (state->remote.toString() + ":" + std::to_string(state->remote.getPort())) : state->getName();
boost::replace_all(serverName, ".", "_");
const string base = namespace_name + "." + hostname + "." + instance_name + ".servers." + serverName + ".";
str<<base<<"queries" << ' ' << state->queries.load() << " " << now << "\r\n";
{ "LogAction", true, "[filename], [binary], [append], [buffered]", "Log a line for each query, to the specified file if any, to the console (require verbose) otherwise. When logging to a file, the `binary` optional parameter specifies whether we log in binary form (default) or in textual form, the `append` optional parameter specifies whether we open the file for appending or truncate each time (default), and the `buffered` optional parameter specifies whether writes to the file are buffered (default) or not." },
{ "LogResponseAction", true, "[filename], [append], [buffered]", "Log a line for each response, to the specified file if any, to the console (require verbose) otherwise. The `append` optional parameter specifies whether we open the file for appending or truncate each time (default), and the `buffered` optional parameter specifies whether writes to the file are buffered (default) or not." },
{ "LuaAction", true, "function", "Invoke a Lua function that accepts a DNSQuestion" },
+ { "LuaFFIAction", true, "function", "Invoke a Lua FFI function that accepts a DNSQuestion" },
+ { "LuaFFIResponseAction", true, "function", "Invoke a Lua FFI function that accepts a DNSResponse" },
+ { "LuaFFIRule", true, "function", "Invoke a Lua FFI function that filters DNS questions" },
{ "LuaResponseAction", true, "function", "Invoke a Lua function that accepts a DNSResponse" },
+ { "LuaRule", true, "function", "Invoke a Lua function that filters DNS questions" },
{ "MacAddrAction", true, "option", "Add the source MAC address to the query as EDNS0 option option. This action is currently only supported on Linux. Subsequent rules are processed after this action" },
{ "makeIPCipherKey", true, "password", "generates a 16-byte key that can be used to pseudonymize IP addresses with IP cipher" },
{ "makeKey", true, "", "generate a new server access key, emit configuration line ready for pasting" },
{ "setSecurityPollSuffix", true, "suffix", "set the security polling suffix to the specified value" },
{ "setServerPolicy", true, "policy", "set server selection policy to that policy" },
{ "setServerPolicyLua", true, "name, function", "set server selection policy to one named 'name' and provided by 'function'" },
+ { "setServerPolicyLuaFFI", true, "name, function", "set server selection policy to one named 'name' and provided by the Lua FFI 'function'" },
{ "setServFailWhenNoServer", true, "bool", "if set, return a ServFail when no servers are available, instead of the default behaviour of dropping the query" },
{ "setStaleCacheEntriesTTL", true, "n", "allows using cache entries expired for at most n seconds when there is no backend available to answer for a query" },
{ "setSyslogFacility", true, "facility", "set the syslog logging facility to 'facility'. Defaults to LOG_DAEMON" },
--- /dev/null
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+struct dnsdist_ffi_servers_list_t;
+struct dnsdist_ffi_server_t;
+struct dnsdist_ffi_dnsquestion_t;
+
+struct DownstreamState;
+
+struct ServerPolicy
+{
+ template <class T> using NumberedVector = std::vector<std::pair<unsigned int, T> >;
+ using NumberedServerVector = NumberedVector<shared_ptr<DownstreamState>>;
+ typedef std::function<shared_ptr<DownstreamState>(const NumberedServerVector& servers, const DNSQuestion*)> policyfunc_t;
+ typedef std::function<unsigned int(dnsdist_ffi_servers_list_t* servers, dnsdist_ffi_dnsquestion_t* dq)> ffipolicyfunc_t;
+
+ ServerPolicy(const std::string& name_, policyfunc_t policy_, bool isLua_): name(name_), policy(policy_), isLua(isLua_)
+ {
+ }
+ ServerPolicy(const std::string& name_, ffipolicyfunc_t policy_): name(name_), ffipolicy(policy_), isLua(true), isFFI(true)
+ {
+ }
+ ServerPolicy()
+ {
+ }
+
+ string name;
+ policyfunc_t policy;
+ ffipolicyfunc_t ffipolicy;
+ bool isLua{false};
+ bool isFFI{false};
+
+ std::string toString() const {
+ return string("ServerPolicy") + (isLua ? " (Lua)" : "") + " \"" + name + "\"";
+ }
+};
+
+struct ServerPool;
+
+using pools_t=map<std::string,std::shared_ptr<ServerPool>>;
+std::shared_ptr<ServerPool> getPool(const pools_t& pools, const std::string& poolName);
+std::shared_ptr<ServerPool> createPoolIfNotExists(pools_t& pools, const string& poolName);
+void setPoolPolicy(pools_t& pools, const string& poolName, std::shared_ptr<ServerPolicy> policy);
+void addServerToPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server);
+void removeServerFromPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server);
+
+ServerPolicy::NumberedServerVector getDownstreamCandidates(const map<std::string,std::shared_ptr<ServerPool>>& pools, const std::string& poolName);
+
+std::shared_ptr<DownstreamState> firstAvailable(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+
+std::shared_ptr<DownstreamState> leastOutstanding(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::shared_ptr<DownstreamState> wrandom(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::shared_ptr<DownstreamState> whashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::shared_ptr<DownstreamState> whashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash);
+std::shared_ptr<DownstreamState> chashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::shared_ptr<DownstreamState> chashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash);
+std::shared_ptr<DownstreamState> roundrobin(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::shared_ptr<DownstreamState> getSelectedBackendFromPolicy(const ServerPolicy& policy, const ServerPolicy::NumberedServerVector& servers, DNSQuestion& dq);
#include "dnsdist.hh"
#include "dnsdist-ecs.hh"
#include "dnsdist-lua.hh"
+#include "dnsdist-lua-ffi.hh"
#include "dnsdist-protobuf.hh"
#include "dnsdist-kvs.hh"
}
};
-DNSAction::Action LuaAction::operator()(DNSQuestion* dq, std::string* ruleresult) const
+class LuaAction : public DNSAction
{
- std::lock_guard<std::mutex> lock(g_luamutex);
- try {
- auto ret = d_func(dq);
- if (ruleresult) {
- if (boost::optional<std::string> rule = std::get<1>(ret)) {
- *ruleresult = *rule;
+public:
+ typedef std::function<std::tuple<int, boost::optional<string> >(DNSQuestion* dq)> func_t;
+ LuaAction(const LuaAction::func_t& func) : d_func(func)
+ {}
+
+ DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
+ {
+ std::lock_guard<std::mutex> lock(g_luamutex);
+ try {
+ auto ret = d_func(dq);
+ if (ruleresult) {
+ if (boost::optional<std::string> rule = std::get<1>(ret)) {
+ *ruleresult = *rule;
+ }
+ else {
+ // default to empty string
+ ruleresult->clear();
+ }
}
- else {
- // default to empty string
- ruleresult->clear();
+ return static_cast<Action>(std::get<0>(ret));
+ } catch (const std::exception &e) {
+ warnlog("LuaAction failed inside Lua, returning ServFail: %s", e.what());
+ } catch (...) {
+ warnlog("LuaAction failed inside Lua, returning ServFail: [unknown exception]");
+ }
+ return DNSAction::Action::ServFail;
+ }
+
+ string toString() const override
+ {
+ return "Lua script";
+ }
+private:
+ func_t d_func;
+};
+
+class LuaResponseAction : public DNSResponseAction
+{
+public:
+ typedef std::function<std::tuple<int, boost::optional<string> >(DNSResponse* dr)> func_t;
+ LuaResponseAction(const LuaResponseAction::func_t& func) : d_func(func)
+ {}
+ DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override
+ {
+ std::lock_guard<std::mutex> lock(g_luamutex);
+ try {
+ auto ret = d_func(dr);
+ if(ruleresult) {
+ if (boost::optional<std::string> rule = std::get<1>(ret)) {
+ *ruleresult = *rule;
+ }
+ else {
+ // default to empty string
+ ruleresult->clear();
+ }
}
+ return static_cast<Action>(std::get<0>(ret));
+ } catch (const std::exception &e) {
+ warnlog("LuaResponseAction failed inside Lua, returning ServFail: %s", e.what());
+ } catch (...) {
+ warnlog("LuaResponseAction failed inside Lua, returning ServFail: [unknown exception]");
}
- return (Action)std::get<0>(ret);
- } catch (std::exception &e) {
- warnlog("LuaAction failed inside lua, returning ServFail: %s", e.what());
- } catch (...) {
- warnlog("LuaAction failed inside lua, returning ServFail: [unknown exception]");
+ return DNSResponseAction::Action::ServFail;
}
- return DNSAction::Action::ServFail;
-}
-DNSResponseAction::Action LuaResponseAction::operator()(DNSResponse* dr, std::string* ruleresult) const
+ string toString() const override
+ {
+ return "Lua response script";
+ }
+private:
+ func_t d_func;
+};
+
+class LuaFFIAction: public DNSAction
{
- std::lock_guard<std::mutex> lock(g_luamutex);
- try {
- auto ret = d_func(dr);
- if(ruleresult) {
- if (boost::optional<std::string> rule = std::get<1>(ret)) {
- *ruleresult = *rule;
+public:
+ typedef std::function<int(dnsdist_ffi_dnsquestion_t* dq)> func_t;
+
+ LuaFFIAction(const LuaFFIAction::func_t& func): d_func(func)
+ {
+ }
+
+ DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
+ {
+ dnsdist_ffi_dnsquestion_t dqffi(dq);
+ try {
+ std::lock_guard<std::mutex> lock(g_luamutex);
+
+ auto ret = d_func(&dqffi);
+ if (ruleresult) {
+ if (dqffi.result) {
+ *ruleresult = *dqffi.result;
+ }
+ else {
+ // default to empty string
+ ruleresult->clear();
+ }
}
- else {
- // default to empty string
- ruleresult->clear();
+ return static_cast<DNSAction::Action>(ret);
+ } catch (const std::exception &e) {
+ warnlog("LuaFFIAction failed inside Lua, returning ServFail: %s", e.what());
+ } catch (...) {
+ warnlog("LuaFFIAction failed inside Lua, returning ServFail: [unknown exception]");
+ }
+ return DNSAction::Action::ServFail;
+ }
+
+ string toString() const override
+ {
+ return "Lua FFI script";
+ }
+private:
+ func_t d_func;
+};
+
+
+class LuaFFIResponseAction: public DNSResponseAction
+{
+public:
+ typedef std::function<int(dnsdist_ffi_dnsquestion_t* dq)> func_t;
+
+ LuaFFIResponseAction(const LuaFFIResponseAction::func_t& func): d_func(func)
+ {
+ }
+
+ DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override
+ {
+ DNSQuestion* dq = dynamic_cast<DNSQuestion*>(dr);
+ if (dq == nullptr) {
+ return DNSResponseAction::Action::ServFail;
+ }
+
+ dnsdist_ffi_dnsquestion_t dqffi(dq);
+ try {
+ std::lock_guard<std::mutex> lock(g_luamutex);
+
+ auto ret = d_func(&dqffi);
+ if (ruleresult) {
+ if (dqffi.result) {
+ *ruleresult = *dqffi.result;
+ }
+ else {
+ // default to empty string
+ ruleresult->clear();
+ }
}
+ return static_cast<DNSResponseAction::Action>(ret);
+ } catch (const std::exception &e) {
+ warnlog("LuaFFIResponseAction failed inside Lua, returning ServFail: %s", e.what());
+ } catch (...) {
+ warnlog("LuaFFIResponseAction failed inside Lua, returning ServFail: [unknown exception]");
}
- return (Action)std::get<0>(ret);
- } catch (std::exception &e) {
- warnlog("LuaResponseAction failed inside lua, returning ServFail: %s", e.what());
- } catch (...) {
- warnlog("LuaResponseAction failed inside lua, returning ServFail: [unknown exception]");
+ return DNSResponseAction::Action::ServFail;
}
- return DNSResponseAction::Action::ServFail;
-}
+
+ string toString() const override
+ {
+ return "Lua FFI script";
+ }
+private:
+ func_t d_func;
+};
DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresult) const
{
return std::shared_ptr<DNSAction>(new LuaAction(func));
});
+ g_lua.writeFunction("LuaFFIAction", [](LuaFFIAction::func_t func) {
+ setLuaSideEffect();
+ return std::shared_ptr<DNSAction>(new LuaFFIAction(func));
+ });
+
g_lua.writeFunction("NoRecurseAction", []() {
return std::shared_ptr<DNSAction>(new NoRecurseAction);
});
return std::shared_ptr<DNSResponseAction>(new LuaResponseAction(func));
});
+ g_lua.writeFunction("LuaFFIResponseAction", [](LuaFFIResponseAction::func_t func) {
+ setLuaSideEffect();
+ return std::shared_ptr<DNSResponseAction>(new LuaFFIResponseAction(func));
+ });
+
g_lua.writeFunction("RemoteLogAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)> > alterFunc, boost::optional<std::unordered_map<std::string, std::string>> vars) {
if (logger) {
// avoids potentially-evaluated-expression warning with clang.
return *dq.ednsOptions;
});
g_lua.registerFunction<std::string(DNSQuestion::*)(void)>("getTrailingData", [](const DNSQuestion& dq) {
- const char* message = reinterpret_cast<const char*>(dq.dh);
- const uint16_t messageLen = getDNSPacketLength(message, dq.len);
- const std::string tail = std::string(message + messageLen, dq.len - messageLen);
- return tail;
+ return dq.getTrailingData();
});
g_lua.registerFunction<bool(DNSQuestion::*)(std::string)>("setTrailingData", [](DNSQuestion& dq, const std::string& tail) {
- char* message = reinterpret_cast<char*>(dq.dh);
- const uint16_t messageLen = getDNSPacketLength(message, dq.len);
- const uint16_t tailLen = tail.size();
- if(tailLen > (dq.size - messageLen)) {
- return false;
- }
-
- /* Update length and copy data from the Lua string. */
- dq.len = messageLen + tailLen;
- if(tailLen > 0) {
- tail.copy(message + messageLen, tailLen);
- }
- return true;
+ return dq.setTrailingData(tail);
});
g_lua.registerFunction<std::string(DNSQuestion::*)()>("getServerNameIndication", [](const DNSQuestion& dq) {
editDNSPacketTTL((char*) dr.dh, dr.len, editFunc);
});
g_lua.registerFunction<std::string(DNSResponse::*)(void)>("getTrailingData", [](const DNSResponse& dq) {
- const char* message = reinterpret_cast<const char*>(dq.dh);
- const uint16_t messageLen = getDNSPacketLength(message, dq.len);
- const std::string tail = std::string(message + messageLen, dq.len - messageLen);
- return tail;
+ return dq.getTrailingData();
});
g_lua.registerFunction<bool(DNSResponse::*)(std::string)>("setTrailingData", [](DNSResponse& dq, const std::string& tail) {
- char* message = reinterpret_cast<char*>(dq.dh);
- const uint16_t messageLen = getDNSPacketLength(message, dq.len);
- const uint16_t tailLen = tail.size();
- if(tailLen > (dq.size - messageLen)) {
- return false;
- }
-
- /* Update length and copy data from the Lua string. */
- dq.len = messageLen + tailLen;
- if(tailLen > 0) {
- tail.copy(message + messageLen, tailLen);
- }
- return true;
+ return dq.setTrailingData(tail);
});
g_lua.registerFunction<void(DNSResponse::*)(std::string, std::string)>("setTag", [](DNSResponse& dr, const std::string& strLabel, const std::string& strValue) {
return string("No exception");
});
/* ServerPolicy */
- g_lua.writeFunction("newServerPolicy", [](string name, policyfunc_t policy) { return ServerPolicy{name, policy, true};});
+ g_lua.writeFunction("newServerPolicy", [](string name, ServerPolicy::policyfunc_t policy) { return std::make_shared<ServerPolicy>(name, policy, true);});
g_lua.registerMember("name", &ServerPolicy::name);
g_lua.registerMember("policy", &ServerPolicy::policy);
+ g_lua.registerMember("ffipolicy", &ServerPolicy::ffipolicy);
g_lua.registerMember("isLua", &ServerPolicy::isLua);
+ g_lua.registerMember("isFFI", &ServerPolicy::isFFI);
g_lua.registerFunction("toString", &ServerPolicy::toString);
g_lua.writeVariable("firstAvailable", ServerPolicy{"firstAvailable", firstAvailable, false});
[](DownstreamState& s, int newWeight) {s.setWeight(newWeight);}
);
g_lua.registerMember("order", &DownstreamState::order);
- g_lua.registerMember("name", &DownstreamState::name);
+ g_lua.registerMember<const std::string(DownstreamState::*)>("name", [](const DownstreamState& backend) -> const std::string { return backend.getName(); }, [](DownstreamState& backend, const std::string& newName) { backend.setName(newName); });
g_lua.registerFunction<std::string(DownstreamState::*)()>("getID", [](const DownstreamState& s) { return boost::uuids::to_string(s.id); });
/* dnsheader */
/* ComboAddress */
g_lua.writeFunction("newCA", [](const std::string& name) { return ComboAddress(name); });
+ g_lua.writeFunction("newCAFromRaw", [](const std::string& raw, boost::optional<uint16_t> port) {
+ if (raw.size() == 4) {
+ struct sockaddr_in sin4;
+ memset(&sin4, 0, sizeof(sin4));
+ sin4.sin_family = AF_INET;
+ memcpy(&sin4.sin_addr.s_addr, raw.c_str(), raw.size());
+ if (port) {
+ sin4.sin_port = htons(*port);
+ }
+ return ComboAddress(&sin4);
+ }
+ else if (raw.size() == 16) {
+ struct sockaddr_in6 sin6;
+ memset(&sin6, 0, sizeof(sin6));
+ sin6.sin6_family = AF_INET6;
+ memcpy(&sin6.sin6_addr.s6_addr, raw.c_str(), raw.size());
+ if (port) {
+ sin6.sin6_port = htons(*port);
+ }
+ return ComboAddress(&sin6);
+ }
+ return ComboAddress();
+ });
g_lua.registerFunction<string(ComboAddress::*)()>("tostring", [](const ComboAddress& ca) { return ca.toString(); });
g_lua.registerFunction<string(ComboAddress::*)()>("tostringWithPort", [](const ComboAddress& ca) { return ca.toStringWithPort(); });
g_lua.registerFunction<string(ComboAddress::*)()>("toString", [](const ComboAddress& ca) { return ca.toString(); });
g_lua.registerFunction("isPartOf", &DNSName::isPartOf);
g_lua.registerFunction<bool(DNSName::*)()>("chopOff", [](DNSName&dn ) { return dn.chopOff(); });
g_lua.registerFunction<unsigned int(DNSName::*)()>("countLabels", [](const DNSName& name) { return name.countLabels(); });
+ g_lua.registerFunction<size_t(DNSName::*)()>("hash", [](const DNSName& name) { return name.hash(); });
g_lua.registerFunction<size_t(DNSName::*)()>("wirelength", [](const DNSName& name) { return name.wirelength(); });
g_lua.registerFunction<string(DNSName::*)()>("tostring", [](const DNSName&dn ) { return dn.toString(); });
g_lua.registerFunction<string(DNSName::*)()>("toString", [](const DNSName&dn ) { return dn.toString(); });
g_lua.writeFunction("newDNSName", [](const std::string& name) { return DNSName(name); });
+ g_lua.writeFunction("newDNSNameFromRaw", [](const std::string& name) { return DNSName(name.c_str(), name.size(), 0, false); });
g_lua.writeFunction("newSuffixMatchNode", []() { return SuffixMatchNode(); });
g_lua.writeFunction("newDNSNameSet", []() { return DNSNameSet(); });
});
g_lua.registerFunction("check",(bool (SuffixMatchNode::*)(const DNSName&) const) &SuffixMatchNode::check);
+ /* Netmask */
+ g_lua.writeFunction("newNetmask", [](boost::variant<std::string,ComboAddress> s, boost::optional<uint8_t> bits) {
+ if (s.type() == typeid(ComboAddress)) {
+ auto ca = boost::get<ComboAddress>(s);
+ if (bits) {
+ return Netmask(ca, *bits);
+ }
+ return Netmask(ca);
+ }
+ else if (s.type() == typeid(std::string)) {
+ auto str = boost::get<std::string>(s);
+ return Netmask(str);
+ }
+ throw std::runtime_error("Invalid parameter passed to 'newNetmask()'");
+ });
+ g_lua.registerFunction("empty", &Netmask::empty);
+ g_lua.registerFunction("getBits", &Netmask::getBits);
+ g_lua.registerFunction<ComboAddress(Netmask::*)()>("getNetwork", [](const Netmask& nm) { return nm.getNetwork(); } ); // const reference makes this necessary
+ g_lua.registerFunction<ComboAddress(Netmask::*)()>("getMaskedNetwork", [](const Netmask& nm) { return nm.getMaskedNetwork(); } );
+ g_lua.registerFunction("isIpv4", &Netmask::isIPv4);
+ g_lua.registerFunction("isIPv4", &Netmask::isIPv4);
+ g_lua.registerFunction("isIpv6", &Netmask::isIPv6);
+ g_lua.registerFunction("isIPv6", &Netmask::isIPv6);
+ g_lua.registerFunction("match", (bool (Netmask::*)(const string&) const)&Netmask::match);
+ g_lua.registerFunction("toString", &Netmask::toString);
+ g_lua.registerEqFunction(&Netmask::operator==);
+ g_lua.registerToStringFunction(&Netmask::toString);
+
/* NetmaskGroup */
g_lua.writeFunction("newNMG", []() { return NetmaskGroup(); });
g_lua.registerFunction<void(NetmaskGroup::*)(const std::string&mask)>("addMask", [](NetmaskGroup&nmg, const std::string& mask)
auto states = g_dstates.getLocal();
counter = 0;
for(const auto& s : *states) {
- ret << (fmt % counter % s->name % s->remote.toStringWithPort() % s->tcpCurrentConnections % s->tcpDiedSendingQuery % s->tcpDiedReadingResponse % s->tcpGaveUp % s->tcpReadTimeouts % s->tcpWriteTimeouts % s->tcpAvgQueriesPerConnection % s->tcpAvgConnectionDuration) << endl;
+ ret << (fmt % counter % s->getName() % s->remote.toStringWithPort() % s->tcpCurrentConnections % s->tcpDiedSendingQuery % s->tcpDiedReadingResponse % s->tcpGaveUp % s->tcpReadTimeouts % s->tcpWriteTimeouts % s->tcpAvgQueriesPerConnection % s->tcpAvgConnectionDuration) << endl;
++counter;
}
g_lua.writeFunction("KeyValueStoreLookupRule", [](std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey) {
return std::shared_ptr<DNSRule>(new KeyValueStoreLookupRule(kvs, lookupKey));
});
+
+ g_lua.writeFunction("LuaRule", [](LuaRule::func_t func) {
+ return std::shared_ptr<DNSRule>(new LuaRule(func));
+ });
+
+ g_lua.writeFunction("LuaFFIRule", [](LuaFFIRule::func_t func) {
+ return std::shared_ptr<DNSRule>(new LuaFFIRule(func));
+ });
}
#include "dnsdist-ecs.hh"
#include "dnsdist-healthchecks.hh"
#include "dnsdist-lua.hh"
+#ifdef LUAJIT_VERSION
+#include "dnsdist-lua-ffi.hh"
+#endif /* LUAJIT_VERSION */
#include "dnsdist-rings.hh"
#include "dnsdist-secpoll.hh"
#endif // defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
-void setupLuaConfig(bool client, bool configCheck)
+static void setupLuaConfig(bool client, bool configCheck)
{
typedef std::unordered_map<std::string, boost::variant<bool, std::string, vector<pair<int, std::string> >, DownstreamState::checkfunc_t > > newserver_t;
g_lua.writeFunction("inClientStartup", [client]() {
// create but don't connect the socket in client or check-config modes
ret=std::make_shared<DownstreamState>(serverAddr, sourceAddr, sourceItf, sourceItfName, numberOfSockets, !(client || configCheck));
+ if (!(client || configCheck)) {
+ infolog("Added downstream server %s", serverAddr.toStringWithPort());
+ }
if(vars.count("qps")) {
int qpsVal=std::stoi(boost::get<string>(vars["qps"]));
}
if(vars.count("name")) {
- ret->name=boost::get<string>(vars["name"]);
+ ret->setName(boost::get<string>(vars["name"]));
}
if (vars.count("id")) {
g_dstates.setState(states);
} );
- g_lua.writeFunction("setServerPolicy", [](ServerPolicy policy) {
- setLuaSideEffect();
- g_policy.setState(policy);
- });
- g_lua.writeFunction("setServerPolicyLua", [](string name, policyfunc_t policy) {
- setLuaSideEffect();
- g_policy.setState(ServerPolicy{name, policy, true});
- });
-
- g_lua.writeFunction("showServerPolicy", []() {
- setLuaSideEffect();
- g_outputBuffer=g_policy.getLocal()->name+"\n";
- });
-
g_lua.writeFunction("truncateTC", [](bool tc) { setLuaSideEffect(); g_truncateTC=tc; });
g_lua.writeFunction("fixupCase", [](bool fu) { setLuaSideEffect(); g_fixupCase=fu; });
pools+=p;
}
if (showUUIDs) {
- ret << (fmt % counter % s->name % s->remote.toStringWithPort() %
+ ret << (fmt % counter % s->getName() % s->remote.toStringWithPort() %
status %
s->queryLoad % s->qps.getRate() % s->order % s->weight % s->queries.load() % s->reuseds.load() % (s->dropRate) % (s->latencyUsec/1000.0) % s->outstanding.load() % pools % s->id) << endl;
} else {
- ret << (fmt % counter % s->name % s->remote.toStringWithPort() %
+ ret << (fmt % counter % s->getName() % s->remote.toStringWithPort() %
status %
s->queryLoad % s->qps.getRate() % s->order % s->weight % s->queries.load() % s->reuseds.load() % (s->dropRate) % (s->latencyUsec/1000.0) % s->outstanding.load() % pools) << endl;
}
if (!servers.empty()) {
servers += ", ";
}
- if (!server.second->name.empty()) {
- servers += server.second->name;
+ if (!server.second->getName().empty()) {
+ servers += server.second->getName();
servers += " ";
}
servers += server.second->remote.toStringWithPort();
#endif /* HAVE_NET_SNMP */
});
+ g_lua.writeFunction("setServerPolicy", [](ServerPolicy policy) {
+ setLuaSideEffect();
+ g_policy.setState(policy);
+ });
+
+ g_lua.writeFunction("setServerPolicyLua", [](string name, ServerPolicy::policyfunc_t policy) {
+ setLuaSideEffect();
+ g_policy.setState(ServerPolicy{name, policy, true});
+ });
+
+ g_lua.writeFunction("setServerPolicyLuaFFI", [](string name, ServerPolicy::ffipolicyfunc_t policy) {
+ setLuaSideEffect();
+ auto pol = ServerPolicy(name, policy);
+ g_policy.setState(std::move(pol));
+ });
+
+ g_lua.writeFunction("showServerPolicy", []() {
+ setLuaSideEffect();
+ g_outputBuffer=g_policy.getLocal()->name+"\n";
+ });
+
g_lua.writeFunction("setPoolServerPolicy", [](ServerPolicy policy, string pool) {
setLuaSideEffect();
auto localPools = g_pools.getCopy();
g_pools.setState(localPools);
});
- g_lua.writeFunction("setPoolServerPolicyLua", [](string name, policyfunc_t policy, string pool) {
+ g_lua.writeFunction("setPoolServerPolicyLua", [](string name, ServerPolicy::policyfunc_t policy, string pool) {
setLuaSideEffect();
auto localPools = g_pools.getCopy();
setPoolPolicy(localPools, pool, std::make_shared<ServerPolicy>(ServerPolicy{name, policy, true}));
setupLuaRules();
setupLuaVars();
+#ifdef LUAJIT_VERSION
+ g_lua.executeCode(getLuaFFIWrappers());
+#endif
+
std::ifstream ifs(config);
if(!ifs)
warnlog("Unable to read configuration from '%s'", config);
};
void setResponseHeadersFromConfig(dnsheader& dh, const ResponseConfig& config);
-class LuaAction : public DNSAction
-{
-public:
- typedef std::function<std::tuple<int, boost::optional<string> >(DNSQuestion* dq)> func_t;
- LuaAction(const LuaAction::func_t& func) : d_func(func)
- {}
- Action operator()(DNSQuestion* dq, string* ruleresult) const override;
- string toString() const override
- {
- return "Lua script";
- }
-private:
- func_t d_func;
-};
-
-class LuaResponseAction : public DNSResponseAction
-{
-public:
- typedef std::function<std::tuple<int, boost::optional<string> >(DNSResponse* dr)> func_t;
- LuaResponseAction(const LuaResponseAction::func_t& func) : d_func(func)
- {}
- Action operator()(DNSResponse* dr, string* ruleresult) const override;
- string toString() const override
- {
- return "Lua response script";
- }
-private:
- func_t d_func;
-};
-
class SpoofAction : public DNSAction
{
public:
static std::unordered_map<oid, DNSDistStats::entry_t> s_statsMap;
+bool g_snmpEnabled{false};
+bool g_snmpTrapsEnabled{false};
+DNSDistSNMPAgent* g_snmpAgent{nullptr};
+
/* We are never called for a GETNEXT if it's registered as a
"instance", as it's "magically" handled for us. */
/* a instance handler also only hands us one request at a time, so
case COLUMN_BACKENDNAME:
snmp_set_var_typed_value(request->requestvb,
ASN_OCTET_STR,
- server->name.c_str(),
- server->name.size());
+ server->getName().c_str(),
+ server->getName().size());
break;
case COLUMN_BACKENDLATENCY:
DNSDistSNMPAgent::setCounter64Value(request,
backendNameOID,
OID_LENGTH(backendNameOID),
ASN_OCTET_STR,
- dss->name.c_str(),
- dss->name.size());
+ dss->getName().c_str(),
+ dss->getName().size());
snmp_varlist_add_variable(&varList,
backendAddressOID,
for (const auto& state : *states) {
string serverName;
- if (state->name.empty())
+ if (state->getName().empty())
serverName = state->remote.toStringWithPort();
else
serverName = state->getName();
Json::object server{
{"id", num++},
- {"name", a->name},
+ {"name", a->getName()},
{"address", a->remote.toStringWithPort()},
{"state", status},
{"qps", (double)a->queryLoad},
GlobalStateHolder<pools_t> g_pools;
size_t g_udpVectorSize{1};
-bool g_snmpEnabled{false};
-bool g_snmpTrapsEnabled{false};
-DNSDistSNMPAgent* g_snmpAgent{nullptr};
-
/* UDP: the grand design. Per socket we listen on for incoming queries there is one thread.
Then we have a bunch of connected sockets for talking to downstream servers.
We send directly to those sockets.
bool g_truncateTC{false};
bool g_fixupCase{false};
bool g_preserveTrailingData{false};
-bool g_roundrobinFailOnNoServer{false};
std::set<std::string> g_capabilitiesToRetain;
DelayPipe<DelayedPacket>* g_delay = nullptr;
+std::string DNSQuestion::getTrailingData() const
+{
+ const char* message = reinterpret_cast<const char*>(this->dh);
+ const uint16_t messageLen = getDNSPacketLength(message, this->len);
+ return std::string(message + messageLen, this->len - messageLen);
+}
+
+bool DNSQuestion::setTrailingData(const std::string& tail)
+{
+ char* message = reinterpret_cast<char*>(this->dh);
+ const uint16_t messageLen = getDNSPacketLength(message, this->len);
+ const uint16_t tailLen = tail.size();
+ if (tailLen > (this->size - messageLen)) {
+ return false;
+ }
+
+ /* Update length and copy data from the Lua string. */
+ this->len = messageLen + tailLen;
+ if(tailLen > 0) {
+ tail.copy(message + messageLen, tailLen);
+ }
+ return true;
+}
+
void doLatencyStats(double udiff)
{
if(udiff < 1000) ++g_stats.latency0_1;
errlog("UDP responder thread died because of an exception: %s", "unknown");
}
-bool DownstreamState::reconnect()
-{
- std::unique_lock<std::mutex> tl(connectLock, std::try_to_lock);
- if (!tl.owns_lock()) {
- /* we are already reconnecting */
- return false;
- }
-
- connected = false;
- for (auto& fd : sockets) {
- if (fd != -1) {
- if (sockets.size() > 1) {
- std::lock_guard<std::mutex> lock(socketsLock);
- mplexer->removeReadFD(fd);
- }
- /* shutdown() is needed to wake up recv() in the responderThread */
- shutdown(fd, SHUT_RDWR);
- close(fd);
- fd = -1;
- }
- if (!IsAnyAddress(remote)) {
- fd = SSocket(remote.sin4.sin_family, SOCK_DGRAM, 0);
- if (!IsAnyAddress(sourceAddr)) {
- SSetsockopt(fd, SOL_SOCKET, SO_REUSEADDR, 1);
- if (!sourceItfName.empty()) {
-#ifdef SO_BINDTODEVICE
- int res = setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, sourceItfName.c_str(), sourceItfName.length());
- if (res != 0) {
- infolog("Error setting up the interface on backend socket '%s': %s", remote.toStringWithPort(), stringerror());
- }
-#endif
- }
-
- SBind(fd, sourceAddr);
- }
- try {
- SConnect(fd, remote);
- if (sockets.size() > 1) {
- std::lock_guard<std::mutex> lock(socketsLock);
- mplexer->addReadFD(fd, [](int, boost::any) {});
- }
- connected = true;
- }
- catch(const std::runtime_error& error) {
- infolog("Error connecting to new server with address %s: %s", remote.toStringWithPort(), error.what());
- connected = false;
- break;
- }
- }
- }
-
- /* if at least one (re-)connection failed, close all sockets */
- if (!connected) {
- for (auto& fd : sockets) {
- if (fd != -1) {
- if (sockets.size() > 1) {
- std::lock_guard<std::mutex> lock(socketsLock);
- mplexer->removeReadFD(fd);
- }
- /* shutdown() is needed to wake up recv() in the responderThread */
- shutdown(fd, SHUT_RDWR);
- close(fd);
- fd = -1;
- }
- }
- }
-
- return connected;
-}
-void DownstreamState::hash()
-{
- vinfolog("Computing hashes for id=%s and weight=%d", id, weight);
- auto w = weight;
- WriteLock wl(&d_lock);
- hashes.clear();
- while (w > 0) {
- std::string uuid = boost::str(boost::format("%s-%d") % id % w);
- unsigned int wshash = burtleCI((const unsigned char*)uuid.c_str(), uuid.size(), g_hashperturb);
- hashes.insert(wshash);
- --w;
- }
-}
-
-void DownstreamState::setId(const boost::uuids::uuid& newId)
-{
- id = newId;
- // compute hashes only if already done
- if (!hashes.empty()) {
- hash();
- }
-}
-
-void DownstreamState::setWeight(int newWeight)
-{
- if (newWeight < 1) {
- errlog("Error setting server's weight: downstream weight value must be greater than 0.");
- return ;
- }
- weight = newWeight;
- if (!hashes.empty()) {
- hash();
- }
-}
-
-DownstreamState::DownstreamState(const ComboAddress& remote_, const ComboAddress& sourceAddr_, unsigned int sourceItf_, const std::string& sourceItfName_, size_t numberOfSockets, bool connect=true): sourceItfName(sourceItfName_), remote(remote_), sourceAddr(sourceAddr_), sourceItf(sourceItf_)
-{
- pthread_rwlock_init(&d_lock, nullptr);
- id = getUniqueID();
- threadStarted.clear();
-
- mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent());
-
- sockets.resize(numberOfSockets);
- for (auto& fd : sockets) {
- fd = -1;
- }
-
- if (connect && !IsAnyAddress(remote)) {
- reconnect();
- idStates.resize(g_maxOutstanding);
- sw.start();
- infolog("Added downstream server %s", remote.toStringWithPort());
- }
-
-}
-
std::mutex g_luamutex;
LuaContext g_lua;
-
-GlobalStateHolder<ServerPolicy> g_policy;
-
-shared_ptr<DownstreamState> firstAvailable(const NumberedServerVector& servers, const DNSQuestion* dq)
-{
- for(auto& d : servers) {
- if(d.second->isUp() && d.second->qps.check())
- return d.second;
- }
- return leastOutstanding(servers, dq);
-}
-
-// get server with least outstanding queries, and within those, with the lowest order, and within those: the fastest
-shared_ptr<DownstreamState> leastOutstanding(const NumberedServerVector& servers, const DNSQuestion* dq)
-{
- if (servers.size() == 1 && servers[0].second->isUp()) {
- return servers[0].second;
- }
-
- vector<pair<tuple<int,int,double>, shared_ptr<DownstreamState>>> poss;
- /* so you might wonder, why do we go through this trouble? The data on which we sort could change during the sort,
- which would suck royally and could even lead to crashes. So first we snapshot on what we sort, and then we sort */
- poss.reserve(servers.size());
- for(auto& d : servers) {
- if(d.second->isUp()) {
- poss.push_back({make_tuple(d.second->outstanding.load(), d.second->order, d.second->latencyUsec), d.second});
- }
- }
- if(poss.empty())
- return shared_ptr<DownstreamState>();
- nth_element(poss.begin(), poss.begin(), poss.end(), [](const decltype(poss)::value_type& a, const decltype(poss)::value_type& b) { return a.first < b.first; });
- return poss.begin()->second;
-}
-
-shared_ptr<DownstreamState> valrandom(unsigned int val, const NumberedServerVector& servers, const DNSQuestion* dq)
-{
- vector<pair<int, shared_ptr<DownstreamState>>> poss;
- int sum = 0;
- int max = std::numeric_limits<int>::max();
-
- for(auto& d : servers) { // w=1, w=10 -> 1, 11
- if(d.second->isUp()) {
- // Don't overflow sum when adding high weights
- if(d.second->weight > max - sum) {
- sum = max;
- } else {
- sum += d.second->weight;
- }
-
- poss.push_back({sum, d.second});
- }
- }
-
- // Catch poss & sum are empty to avoid SIGFPE
- if(poss.empty())
- return shared_ptr<DownstreamState>();
-
- int r = val % sum;
- auto p = upper_bound(poss.begin(), poss.end(),r, [](int r_, const decltype(poss)::value_type& a) { return r_ < a.first;});
- if(p==poss.end())
- return shared_ptr<DownstreamState>();
- return p->second;
-}
-
-shared_ptr<DownstreamState> wrandom(const NumberedServerVector& servers, const DNSQuestion* dq)
-{
- return valrandom(random(), servers, dq);
-}
-
-uint32_t g_hashperturb;
-double g_consistentHashBalancingFactor = 0;
-shared_ptr<DownstreamState> whashed(const NumberedServerVector& servers, const DNSQuestion* dq)
-{
- return valrandom(dq->qname->hash(g_hashperturb), servers, dq);
-}
-
-shared_ptr<DownstreamState> chashed(const NumberedServerVector& servers, const DNSQuestion* dq)
-{
- unsigned int qhash = dq->qname->hash(g_hashperturb);
- unsigned int sel = std::numeric_limits<unsigned int>::max();
- unsigned int min = std::numeric_limits<unsigned int>::max();
- shared_ptr<DownstreamState> ret = nullptr, first = nullptr;
-
- double targetLoad = std::numeric_limits<double>::max();
- if (g_consistentHashBalancingFactor > 0) {
- /* we start with one, representing the query we are currently handling */
- double currentLoad = 1;
- for (const auto& pair : servers) {
- currentLoad += pair.second->outstanding;
- }
- targetLoad = (currentLoad / servers.size()) * g_consistentHashBalancingFactor;
- }
-
- for (const auto& d: servers) {
- if (d.second->isUp() && d.second->outstanding <= targetLoad) {
- // make sure hashes have been computed
- if (d.second->hashes.empty()) {
- d.second->hash();
- }
- {
- ReadLock rl(&(d.second->d_lock));
- const auto& server = d.second;
- // we want to keep track of the last hash
- if (min > *(server->hashes.begin())) {
- min = *(server->hashes.begin());
- first = server;
- }
-
- auto hash_it = server->hashes.lower_bound(qhash);
- if (hash_it != server->hashes.end()) {
- if (*hash_it < sel) {
- sel = *hash_it;
- ret = server;
- }
- }
- }
- }
- }
- if (ret != nullptr) {
- return ret;
- }
- if (first != nullptr) {
- return first;
- }
- return shared_ptr<DownstreamState>();
-}
-
-shared_ptr<DownstreamState> roundrobin(const NumberedServerVector& servers, const DNSQuestion* dq)
-{
- NumberedServerVector poss;
-
- for(auto& d : servers) {
- if(d.second->isUp()) {
- poss.push_back(d);
- }
- }
-
- const auto *res=&poss;
- if(poss.empty() && !g_roundrobinFailOnNoServer)
- res = &servers;
-
- if(res->empty())
- return shared_ptr<DownstreamState>();
-
- static unsigned int counter;
-
- return (*res)[(counter++) % res->size()].second;
-}
-
ComboAddress g_serverControl{"127.0.0.1:5199"};
-std::shared_ptr<ServerPool> createPoolIfNotExists(pools_t& pools, const string& poolName)
-{
- std::shared_ptr<ServerPool> pool;
- pools_t::iterator it = pools.find(poolName);
- if (it != pools.end()) {
- pool = it->second;
- }
- else {
- if (!poolName.empty())
- vinfolog("Creating pool %s", poolName);
- pool = std::make_shared<ServerPool>();
- pools.insert(std::pair<std::string,std::shared_ptr<ServerPool> >(poolName, pool));
- }
- return pool;
-}
-
-void setPoolPolicy(pools_t& pools, const string& poolName, std::shared_ptr<ServerPolicy> policy)
-{
- std::shared_ptr<ServerPool> pool = createPoolIfNotExists(pools, poolName);
- if (!poolName.empty()) {
- vinfolog("Setting pool %s server selection policy to %s", poolName, policy->name);
- } else {
- vinfolog("Setting default pool server selection policy to %s", policy->name);
- }
- pool->policy = policy;
-}
-
-void addServerToPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server)
-{
- std::shared_ptr<ServerPool> pool = createPoolIfNotExists(pools, poolName);
- if (!poolName.empty()) {
- vinfolog("Adding server to pool %s", poolName);
- } else {
- vinfolog("Adding server to default pool");
- }
- pool->addServer(server);
-}
-
-void removeServerFromPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server)
-{
- std::shared_ptr<ServerPool> pool = getPool(pools, poolName);
-
- if (!poolName.empty()) {
- vinfolog("Removing server from pool %s", poolName);
- }
- else {
- vinfolog("Removing server from default pool");
- }
-
- pool->removeServer(server);
-}
-
-std::shared_ptr<ServerPool> getPool(const pools_t& pools, const std::string& poolName)
-{
- pools_t::const_iterator it = pools.find(poolName);
-
- if (it == pools.end()) {
- throw std::out_of_range("No pool named " + poolName);
- }
-
- return it->second;
-}
-
-NumberedServerVector getDownstreamCandidates(const pools_t& pools, const std::string& poolName)
-{
- std::shared_ptr<ServerPool> pool = getPool(pools, poolName);
- return pool->getServers();
-}
static void spoofResponseFromString(DNSQuestion& dq, const string& spoofContent, bool raw)
{
policy = *(serverPool->policy);
}
auto servers = serverPool->getServers();
- if (policy.isLua) {
- std::lock_guard<std::mutex> lock(g_luamutex);
- selectedBackend = policy.policy(servers, &dq);
- }
- else {
- selectedBackend = policy.policy(servers, &dq);
- }
+ selectedBackend = getSelectedBackendFromPolicy(policy, servers, dq);
uint16_t cachedResponseSize = dq.size;
uint32_t allowExpired = selectedBackend ? 0 : g_staleCacheEntriesTTL;
--dss->outstanding;
++g_stats.downstreamTimeouts; // this is an 'actively' discovered timeout
vinfolog("Had a downstream timeout from %s (%s) for query for %s|%s from %s",
- dss->remote.toStringWithPort(), dss->name,
+ dss->remote.toStringWithPort(), dss->getName(),
ids.qname.toLogString(), QType(ids.qtype).getName(), ids.origRemote.toStringWithPort());
struct timespec ts;
#include "dnscrypt.hh"
#include "dnsdist-cache.hh"
#include "dnsdist-dynbpf.hh"
+#include "dnsdist-lbpolicies.hh"
#include "dnsname.hh"
#include "doh.hh"
#include "ednsoptions.hh"
DNSQuestion& operator=(const DNSQuestion&) = delete;
DNSQuestion(DNSQuestion&&) = default;
+ std::string getTrailingData() const;
+ bool setTrailingData(const std::string&);
+
#ifdef HAVE_PROTOBUF
boost::optional<boost::uuids::uuid> uniqueId;
#endif
pthread_rwlock_destroy(&d_lock);
}
boost::uuids::uuid id;
- std::set<unsigned int> hashes;
+ std::vector<unsigned int> hashes;
mutable pthread_rwlock_t d_lock;
std::vector<int> sockets;
const std::string sourceItfName;
std::atomic<double> tcpAvgQueriesPerConnection{0.0};
/* in ms */
std::atomic<double> tcpAvgConnectionDuration{0.0};
- string name;
size_t socketsOffset{0};
double queryLoad{0.0};
double dropRate{0.0};
void setDown() { availability = Availability::Down; }
void setAuto() { availability = Availability::Auto; }
string getName() const {
- if (name.empty()) {
- return remote.toStringWithPort();
- }
return name;
}
string getNameWithAddr() const {
- if (name.empty()) {
- return remote.toStringWithPort();
- }
- return name + " (" + remote.toStringWithPort()+ ")";
+ return nameWithAddr;
+ }
+ void setName(const std::string& newName)
+ {
+ name = newName;
+ nameWithAddr = newName.empty() ? remote.toStringWithPort() : (name + " (" + remote.toStringWithPort()+ ")");
}
+
string getStatus() const
{
string status;
tcpAvgQueriesPerConnection = (99.0 * tcpAvgQueriesPerConnection / 100.0) + (nbQueries / 100.0);
tcpAvgConnectionDuration = (99.0 * tcpAvgConnectionDuration / 100.0) + (durationMs / 100.0);
}
+private:
+ std::string name;
+ std::string nameWithAddr;
};
using servers_t =vector<std::shared_ptr<DownstreamState>>;
-template <class T> using NumberedVector = std::vector<std::pair<unsigned int, T> >;
-
void responderThread(std::shared_ptr<DownstreamState> state);
extern std::mutex g_luamutex;
extern LuaContext g_lua;
mutable std::atomic<uint64_t> d_matches{0};
};
-using NumberedServerVector = NumberedVector<shared_ptr<DownstreamState>>;
-typedef std::function<shared_ptr<DownstreamState>(const NumberedServerVector& servers, const DNSQuestion*)> policyfunc_t;
-
-struct ServerPolicy
-{
- string name;
- policyfunc_t policy;
- bool isLua;
- std::string toString() const {
- return string("ServerPolicy") + (isLua ? " (Lua)" : "") + " \"" + name + "\"";
- }
-};
-
struct ServerPool
{
ServerPool()
return count;
}
- NumberedVector<shared_ptr<DownstreamState>> getServers()
+ ServerPolicy::NumberedServerVector getServers()
{
- NumberedVector<shared_ptr<DownstreamState>> result;
+ ServerPolicy::NumberedServerVector result;
{
ReadLock rl(&d_lock);
result = d_servers;
}
private:
- NumberedVector<shared_ptr<DownstreamState>> d_servers;
+ ServerPolicy::NumberedServerVector d_servers;
pthread_rwlock_t d_lock;
bool d_useECS{false};
};
-using pools_t=map<std::string,std::shared_ptr<ServerPool>>;
-void setPoolPolicy(pools_t& pools, const string& poolName, std::shared_ptr<ServerPolicy> policy);
-void addServerToPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server);
-void removeServerFromPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server);
struct CarbonConfig
{
struct dnsheader;
void controlThread(int fd, ComboAddress local);
-std::shared_ptr<ServerPool> getPool(const pools_t& pools, const std::string& poolName);
-std::shared_ptr<ServerPool> createPoolIfNotExists(pools_t& pools, const string& poolName);
-NumberedServerVector getDownstreamCandidates(const pools_t& pools, const std::string& poolName);
-
-std::shared_ptr<DownstreamState> firstAvailable(const NumberedServerVector& servers, const DNSQuestion* dq);
-
-std::shared_ptr<DownstreamState> leastOutstanding(const NumberedServerVector& servers, const DNSQuestion* dq);
-std::shared_ptr<DownstreamState> wrandom(const NumberedServerVector& servers, const DNSQuestion* dq);
-std::shared_ptr<DownstreamState> whashed(const NumberedServerVector& servers, const DNSQuestion* dq);
-std::shared_ptr<DownstreamState> chashed(const NumberedServerVector& servers, const DNSQuestion* dq);
-std::shared_ptr<DownstreamState> roundrobin(const NumberedServerVector& servers, const DNSQuestion* dq);
+vector<std::function<void(void)>> setupLua(bool client, const std::string& config);
struct WebserverConfig
{
/configure
/depcomp
/dnsdist.1
+/dnsdist-lua-ffi-interface.inc
/dnslabeltext.cc
/ext/ipcrypt/Makefile
/ext/ipcrypt/Makefile.in
dnsmessage.pb.h \
htmlfiles.h.tmp \
htmlfiles.h \
+ dnsdist-lua-ffi-interface.inc \
dnstap.pb.cc \
dnstap.pb.h
$(AM_V_GEN)$(RAGEL) $< -o dnslabeltext.cc
BUILT_SOURCES=htmlfiles.h \
+ dnsdist-lua-ffi-interface.inc \
dnslabeltext.cc
htmlfiles.h: $(srcdir)/html/*
$(AM_V_GEN)$(srcdir)/incfiles $(srcdir) > $@.tmp
@mv $@.tmp $@
+dnsdist-lua-ffi-interface.inc: dnsdist-lua-ffi-interface.h
+ echo 'R"FFIContent(' > $@
+ cat $< >> $@
+ echo ')FFIContent"' >> $@
+
SRC_JS_FILES := $(wildcard src_js/*.js)
MIN_JS_FILES := $(patsubst src_js/%.js,html/js/%.min.js,$(SRC_JS_FILES))
dns.cc dns.hh \
dnscrypt.cc dnscrypt.hh \
dnsdist.cc dnsdist.hh \
+ dnsdist-backend.cc \
dnsdist-dynbpf.cc dnsdist-dynbpf.hh \
dnsdist-cache.cc dnsdist-cache.hh \
dnsdist-carbon.cc \
dnsdist-healthchecks.cc dnsdist-healthchecks.hh \
dnsdist-idstate.cc \
dnsdist-kvs.hh dnsdist-kvs.cc \
+ dnsdist-lbpolicies.cc dnsdist-lbpolicies.hh \
dnsdist-lua.hh dnsdist-lua.cc \
dnsdist-lua-actions.cc \
dnsdist-lua-bindings.cc \
dnsdist-lua-bindings-kvs.cc \
dnsdist-lua-bindings-packetcache.cc \
dnsdist-lua-bindings-protobuf.cc \
+ dnsdist-lua-ffi.cc dnsdist-lua-ffi.hh \
+ dnsdist-lua-ffi-interface.h dnsdist-lua-ffi-interface.inc \
dnsdist-lua-inspection.cc \
dnsdist-lua-inspection-ffi.cc dnsdist-lua-inspection-ffi.hh \
dnsdist-lua-rules.cc \
test-dnsdist_cc.cc \
test-dnsdistdynblocks_hh.cc \
test-dnsdistkvs_cc.cc \
+ test-dnsdistlbpolicies_cc.cc \
test-dnsdistpacketcache_cc.cc \
test-dnsdistrings_cc.cc \
test-dnsdistrules_cc.cc \
cachecleaner.hh \
circular_buffer.hh \
dnsdist.hh \
+ dnsdist-backend.cc \
dnsdist-cache.cc dnsdist-cache.hh \
dnsdist-dynblocks.cc dnsdist-dynblocks.hh \
dnsdist-ecs.cc dnsdist-ecs.hh \
dnsdist-kvs.cc dnsdist-kvs.hh \
+ dnsdist-lbpolicies.cc dnsdist-lbpolicies.hh \
+ dnsdist-lua-ffi.cc dnsdist-lua-ffi.hh \
+ dnsdist-lua-ffi-interface.h dnsdist-lua-ffi-interface.inc \
dnsdist-rings.hh \
dnsdist-xpf.cc dnsdist-xpf.hh \
dnscrypt.cc dnscrypt.hh \
statnode.cc statnode.hh \
threadname.hh threadname.cc \
testrunner.cc \
- xpf.cc xpf.hh
+ uuid-utils.hh uuid-utils.cc \
+ xpf.cc xpf.hh \
+ ext/luawrapper/include/LuaContext.hpp
dnsdist_LDFLAGS = \
$(AM_LDFLAGS) \
testrunner_LDADD = \
$(BOOST_UNIT_TEST_FRAMEWORK_LIBS) \
- $(LIBSODIUM_LIBS) \
$(FSTRM_LIBS) \
+ $(LIBSODIUM_LIBS) \
+ $(LUA_LIBS) \
$(RT_LIBS) \
$(SANITIZER_FLAGS) \
$(LIBCAP_LIBS)
--- /dev/null
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "dnsdist.hh"
+#include "dolog.hh"
+
+bool DownstreamState::reconnect()
+{
+ std::unique_lock<std::mutex> tl(connectLock, std::try_to_lock);
+ if (!tl.owns_lock()) {
+ /* we are already reconnecting */
+ return false;
+ }
+
+ connected = false;
+ for (auto& fd : sockets) {
+ if (fd != -1) {
+ if (sockets.size() > 1) {
+ std::lock_guard<std::mutex> lock(socketsLock);
+ mplexer->removeReadFD(fd);
+ }
+ /* shutdown() is needed to wake up recv() in the responderThread */
+ shutdown(fd, SHUT_RDWR);
+ close(fd);
+ fd = -1;
+ }
+ if (!IsAnyAddress(remote)) {
+ fd = SSocket(remote.sin4.sin_family, SOCK_DGRAM, 0);
+ if (!IsAnyAddress(sourceAddr)) {
+ SSetsockopt(fd, SOL_SOCKET, SO_REUSEADDR, 1);
+ if (!sourceItfName.empty()) {
+#ifdef SO_BINDTODEVICE
+ int res = setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, sourceItfName.c_str(), sourceItfName.length());
+ if (res != 0) {
+ infolog("Error setting up the interface on backend socket '%s': %s", remote.toStringWithPort(), stringerror());
+ }
+#endif
+ }
+
+ SBind(fd, sourceAddr);
+ }
+ try {
+ SConnect(fd, remote);
+ if (sockets.size() > 1) {
+ std::lock_guard<std::mutex> lock(socketsLock);
+ mplexer->addReadFD(fd, [](int, boost::any) {});
+ }
+ connected = true;
+ }
+ catch(const std::runtime_error& error) {
+ infolog("Error connecting to new server with address %s: %s", remote.toStringWithPort(), error.what());
+ connected = false;
+ break;
+ }
+ }
+ }
+
+ /* if at least one (re-)connection failed, close all sockets */
+ if (!connected) {
+ for (auto& fd : sockets) {
+ if (fd != -1) {
+ if (sockets.size() > 1) {
+ std::lock_guard<std::mutex> lock(socketsLock);
+ mplexer->removeReadFD(fd);
+ }
+ /* shutdown() is needed to wake up recv() in the responderThread */
+ shutdown(fd, SHUT_RDWR);
+ close(fd);
+ fd = -1;
+ }
+ }
+ }
+
+ return connected;
+}
+void DownstreamState::hash()
+{
+ vinfolog("Computing hashes for id=%s and weight=%d", id, weight);
+ auto w = weight;
+ WriteLock wl(&d_lock);
+ hashes.clear();
+ hashes.reserve(w);
+ while (w > 0) {
+ std::string uuid = boost::str(boost::format("%s-%d") % id % w);
+ unsigned int wshash = burtleCI(reinterpret_cast<const unsigned char*>(uuid.c_str()), uuid.size(), g_hashperturb);
+ hashes.push_back(wshash);
+ --w;
+ }
+ std::sort(hashes.begin(), hashes.end());
+}
+
+void DownstreamState::setId(const boost::uuids::uuid& newId)
+{
+ id = newId;
+ // compute hashes only if already done
+ if (!hashes.empty()) {
+ hash();
+ }
+}
+
+void DownstreamState::setWeight(int newWeight)
+{
+ if (newWeight < 1) {
+ errlog("Error setting server's weight: downstream weight value must be greater than 0.");
+ return ;
+ }
+ weight = newWeight;
+ if (!hashes.empty()) {
+ hash();
+ }
+}
+
+DownstreamState::DownstreamState(const ComboAddress& remote_, const ComboAddress& sourceAddr_, unsigned int sourceItf_, const std::string& sourceItfName_, size_t numberOfSockets, bool connect=true): sourceItfName(sourceItfName_), remote(remote_), sourceAddr(sourceAddr_), sourceItf(sourceItf_), name(remote_.toStringWithPort()), nameWithAddr(remote_.toStringWithPort())
+{
+ pthread_rwlock_init(&d_lock, nullptr);
+ id = getUniqueID();
+ threadStarted.clear();
+
+ mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent());
+
+ sockets.resize(numberOfSockets);
+ for (auto& fd : sockets) {
+ fd = -1;
+ }
+
+ if (connect && !IsAnyAddress(remote)) {
+ reconnect();
+ idStates.resize(g_maxOutstanding);
+ sw.start();
+ }
+
+}
--- /dev/null
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "dnsdist.hh"
+#include "dnsdist-lbpolicies.hh"
+#include "dnsdist-lua-ffi.hh"
+#include "dolog.hh"
+
+GlobalStateHolder<ServerPolicy> g_policy;
+bool g_roundrobinFailOnNoServer{false};
+
+// get server with least outstanding queries, and within those, with the lowest order, and within those: the fastest
+shared_ptr<DownstreamState> leastOutstanding(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+{
+ if (servers.size() == 1 && servers[0].second->isUp()) {
+ return servers[0].second;
+ }
+
+ vector<pair<tuple<int,int,double>, size_t>> poss;
+ /* so you might wonder, why do we go through this trouble? The data on which we sort could change during the sort,
+ which would suck royally and could even lead to crashes. So first we snapshot on what we sort, and then we sort */
+ poss.reserve(servers.size());
+ size_t position = 0;
+ for(const auto& d : servers) {
+ if(d.second->isUp()) {
+ poss.emplace_back(make_tuple(d.second->outstanding.load(), d.second->order, d.second->latencyUsec), position);
+ }
+ ++position;
+ }
+
+ if (poss.empty()) {
+ return shared_ptr<DownstreamState>();
+ }
+
+ nth_element(poss.begin(), poss.begin(), poss.end(), [](const decltype(poss)::value_type& a, const decltype(poss)::value_type& b) { return a.first < b.first; });
+ return servers.at(poss.begin()->second).second;
+}
+
+shared_ptr<DownstreamState> firstAvailable(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+{
+ for(auto& d : servers) {
+ if(d.second->isUp() && d.second->qps.check())
+ return d.second;
+ }
+ return leastOutstanding(servers, dq);
+}
+
+static shared_ptr<DownstreamState> valrandom(unsigned int val, const ServerPolicy::NumberedServerVector& servers)
+{
+ vector<pair<int, size_t>> poss;
+ poss.reserve(servers.size());
+ int sum = 0;
+ int max = std::numeric_limits<int>::max();
+
+ for(const auto& d : servers) { // w=1, w=10 -> 1, 11
+ if(d.second->isUp()) {
+ // Don't overflow sum when adding high weights
+ if(d.second->weight > max - sum) {
+ sum = max;
+ } else {
+ sum += d.second->weight;
+ }
+
+ poss.emplace_back(sum, d.first);
+ }
+ }
+
+ // Catch poss & sum are empty to avoid SIGFPE
+ if (poss.empty()) {
+ return shared_ptr<DownstreamState>();
+ }
+
+ int r = val % sum;
+ auto p = upper_bound(poss.begin(), poss.end(),r, [](int r_, const decltype(poss)::value_type& a) { return r_ < a.first;});
+ if (p == poss.end()) {
+ return shared_ptr<DownstreamState>();
+ }
+
+ return servers.at(p->second - 1).second;
+}
+
+shared_ptr<DownstreamState> wrandom(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+{
+ return valrandom(random(), servers);
+}
+
+uint32_t g_hashperturb;
+double g_consistentHashBalancingFactor = 0;
+
+shared_ptr<DownstreamState> whashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash)
+{
+ return valrandom(hash, servers);
+}
+
+shared_ptr<DownstreamState> whashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+{
+ return whashedFromHash(servers, dq->qname->hash(g_hashperturb));
+}
+
+shared_ptr<DownstreamState> chashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t qhash)
+{
+ unsigned int sel = std::numeric_limits<unsigned int>::max();
+ unsigned int min = std::numeric_limits<unsigned int>::max();
+ shared_ptr<DownstreamState> ret = nullptr, first = nullptr;
+
+ double targetLoad = std::numeric_limits<double>::max();
+ if (g_consistentHashBalancingFactor > 0) {
+ /* we start with one, representing the query we are currently handling */
+ double currentLoad = 1;
+ for (const auto& pair : servers) {
+ currentLoad += pair.second->outstanding;
+ }
+ targetLoad = (currentLoad / servers.size()) * g_consistentHashBalancingFactor;
+ }
+
+ for (const auto& d: servers) {
+ if (d.second->isUp() && d.second->outstanding <= targetLoad) {
+ // make sure hashes have been computed
+ if (d.second->hashes.empty()) {
+ d.second->hash();
+ }
+ {
+ ReadLock rl(&(d.second->d_lock));
+ const auto& server = d.second;
+ // we want to keep track of the last hash
+ if (min > *(server->hashes.begin())) {
+ min = *(server->hashes.begin());
+ first = server;
+ }
+
+ auto hash_it = std::lower_bound(server->hashes.begin(), server->hashes.end(), qhash);
+ if (hash_it != server->hashes.end()) {
+ if (*hash_it < sel) {
+ sel = *hash_it;
+ ret = server;
+ }
+ }
+ }
+ }
+ }
+ if (ret != nullptr) {
+ return ret;
+ }
+ if (first != nullptr) {
+ return first;
+ }
+ return shared_ptr<DownstreamState>();
+}
+
+shared_ptr<DownstreamState> chashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+{
+ return chashedFromHash(servers, dq->qname->hash(g_hashperturb));
+}
+
+shared_ptr<DownstreamState> roundrobin(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+{
+ ServerPolicy::NumberedServerVector poss;
+
+ for(auto& d : servers) {
+ if(d.second->isUp()) {
+ poss.push_back(d);
+ }
+ }
+
+ const auto *res=&poss;
+ if(poss.empty() && !g_roundrobinFailOnNoServer)
+ res = &servers;
+
+ if(res->empty())
+ return shared_ptr<DownstreamState>();
+
+ static unsigned int counter;
+ return (*res)[(counter++) % res->size()].second;
+}
+
+ServerPolicy::NumberedServerVector getDownstreamCandidates(const pools_t& pools, const std::string& poolName)
+{
+ std::shared_ptr<ServerPool> pool = getPool(pools, poolName);
+ return pool->getServers();
+}
+
+std::shared_ptr<ServerPool> createPoolIfNotExists(pools_t& pools, const string& poolName)
+{
+ std::shared_ptr<ServerPool> pool;
+ pools_t::iterator it = pools.find(poolName);
+ if (it != pools.end()) {
+ pool = it->second;
+ }
+ else {
+ if (!poolName.empty())
+ vinfolog("Creating pool %s", poolName);
+ pool = std::make_shared<ServerPool>();
+ pools.insert(std::pair<std::string,std::shared_ptr<ServerPool> >(poolName, pool));
+ }
+ return pool;
+}
+
+void setPoolPolicy(pools_t& pools, const string& poolName, std::shared_ptr<ServerPolicy> policy)
+{
+ std::shared_ptr<ServerPool> pool = createPoolIfNotExists(pools, poolName);
+ if (!poolName.empty()) {
+ vinfolog("Setting pool %s server selection policy to %s", poolName, policy->name);
+ } else {
+ vinfolog("Setting default pool server selection policy to %s", policy->name);
+ }
+ pool->policy = policy;
+}
+
+void addServerToPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server)
+{
+ std::shared_ptr<ServerPool> pool = createPoolIfNotExists(pools, poolName);
+ if (!poolName.empty()) {
+ vinfolog("Adding server to pool %s", poolName);
+ } else {
+ vinfolog("Adding server to default pool");
+ }
+ pool->addServer(server);
+}
+
+void removeServerFromPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server)
+{
+ std::shared_ptr<ServerPool> pool = getPool(pools, poolName);
+
+ if (!poolName.empty()) {
+ vinfolog("Removing server from pool %s", poolName);
+ }
+ else {
+ vinfolog("Removing server from default pool");
+ }
+
+ pool->removeServer(server);
+}
+
+std::shared_ptr<ServerPool> getPool(const pools_t& pools, const std::string& poolName)
+{
+ pools_t::const_iterator it = pools.find(poolName);
+
+ if (it == pools.end()) {
+ throw std::out_of_range("No pool named " + poolName);
+ }
+
+ return it->second;
+}
+
+std::shared_ptr<DownstreamState> getSelectedBackendFromPolicy(const ServerPolicy& policy, const ServerPolicy::NumberedServerVector& servers, DNSQuestion& dq)
+{
+ std::shared_ptr<DownstreamState> selectedBackend{nullptr};
+
+ if (policy.isLua) {
+ if (!policy.isFFI) {
+ std::lock_guard<std::mutex> lock(g_luamutex);
+ selectedBackend = policy.policy(servers, &dq);
+ }
+ else {
+ dnsdist_ffi_dnsquestion_t dnsq(&dq);
+ dnsdist_ffi_servers_list_t serversList(servers);
+ unsigned int selected = 0;
+ {
+ std::lock_guard<std::mutex> lock(g_luamutex);
+ selected = policy.ffipolicy(&serversList, &dnsq);
+ }
+ selectedBackend = servers.at(selected).second;
+ }
+ }
+ else {
+ selectedBackend = policy.policy(servers, &dq);
+ }
+
+ return selectedBackend;
+}
--- /dev/null
+../dnsdist-lbpolicies.hh
\ No newline at end of file
}
if (keyVar.type() == typeid(ComboAddress)) {
- const auto ca = *boost::get<ComboAddress>(&keyVar);
+ const auto ca = boost::get<ComboAddress>(&keyVar);
KeyValueLookupKeySourceIP lookup;
- for (const auto& key : lookup.getKeys(ca)) {
+ for (const auto& key : lookup.getKeys(*ca)) {
if (kvs->getValue(key, result)) {
return result;
}
}
}
else if (keyVar.type() == typeid(DNSName)) {
- DNSName dn = *boost::get<DNSName>(&keyVar);
+ const DNSName* dn = boost::get<DNSName>(&keyVar);
KeyValueLookupKeyQName lookup(wireFormat ? *wireFormat : true);
- for (const auto& key : lookup.getKeys(dn)) {
+ for (const auto& key : lookup.getKeys(*dn)) {
if (kvs->getValue(key, result)) {
return result;
}
}
}
else if (keyVar.type() == typeid(std::string)) {
- std::string keyStr = *boost::get<std::string>(&keyVar);
- kvs->getValue(keyStr, result);
+ const std::string* keyStr = boost::get<std::string>(&keyVar);
+ kvs->getValue(*keyStr, result);
}
return result;
--- /dev/null
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+/* we don't use a guard (C++ pragma once or even #ifndef because this file (the .inc version)
+ is passed to the Lua FFI wrapper which doesn't support it */
+
+typedef struct dnsdist_ffi_dnsquestion_t dnsdist_ffi_dnsquestion_t;
+typedef struct dnsdist_ffi_servers_list_t dnsdist_ffi_servers_list_t;
+typedef struct dnsdist_ffi_server_t dnsdist_ffi_server_t;
+
+typedef struct dnsdist_ffi_ednsoption {
+ uint16_t optionCode;
+ uint16_t len;
+ const void* data;
+} dnsdist_ffi_ednsoption_t;
+
+typedef struct dnsdist_ffi_http_header {
+ const char* name;
+ const char* value;
+} dnsdist_ffi_http_header_t;
+
+typedef struct dnsdist_ffi_tag {
+ const char* name;
+ const char* value;
+} dnsdist_ffi_tag_t;
+
+
+void dnsdist_ffi_dnsquestion_get_localaddr(const dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize) __attribute__ ((visibility ("default")));
+uint16_t dnsdist_ffi_dnsquestion_get_local_port(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_get_remoteaddr(const dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_get_masked_remoteaddr(dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize, uint8_t bits) __attribute__ ((visibility ("default")));
+uint16_t dnsdist_ffi_dnsquestion_get_remote_port(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_get_qname_raw(const dnsdist_ffi_dnsquestion_t* dq, const char** qname, size_t* qnameSize) __attribute__ ((visibility ("default")));
+size_t dnsdist_ffi_dnsquestion_get_qname_hash(const dnsdist_ffi_dnsquestion_t* dq, size_t init);
+uint16_t dnsdist_ffi_dnsquestion_get_qtype(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+uint16_t dnsdist_ffi_dnsquestion_get_qclass(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+int dnsdist_ffi_dnsquestion_get_rcode(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+void* dnsdist_ffi_dnsquestion_get_header(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+uint16_t dnsdist_ffi_dnsquestion_get_len(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+size_t dnsdist_ffi_dnsquestion_get_size(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+uint8_t dnsdist_ffi_dnsquestion_get_opcode(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_dnsquestion_get_tcp(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_dnsquestion_get_skip_cache(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_dnsquestion_get_use_ecs(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_dnsquestion_get_add_xpf(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_dnsquestion_get_ecs_override(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+uint16_t dnsdist_ffi_dnsquestion_get_ecs_prefix_length(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_dnsquestion_is_temp_failure_ttl_set(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+uint32_t dnsdist_ffi_dnsquestion_get_temp_failure_ttl(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_dnsquestion_get_do(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_get_sni(const dnsdist_ffi_dnsquestion_t* dq, const char** sni, size_t* sniSize) __attribute__ ((visibility ("default")));
+const char* dnsdist_ffi_dnsquestion_get_tag(const dnsdist_ffi_dnsquestion_t* dq, const char* label) __attribute__ ((visibility ("default")));
+const char* dnsdist_ffi_dnsquestion_get_http_path(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+const char* dnsdist_ffi_dnsquestion_get_http_query_string(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+const char* dnsdist_ffi_dnsquestion_get_http_host(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+const char* dnsdist_ffi_dnsquestion_get_http_scheme(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+
+// returns the length of the resulting 'out' array. 'out' is not set if the length is 0
+size_t dnsdist_ffi_dnsquestion_get_edns_options(dnsdist_ffi_dnsquestion_t* ref, const dnsdist_ffi_ednsoption_t** out) __attribute__ ((visibility ("default")));
+size_t dnsdist_ffi_dnsquestion_get_http_headers(dnsdist_ffi_dnsquestion_t* ref, const dnsdist_ffi_http_header_t** out) __attribute__ ((visibility ("default")));
+size_t dnsdist_ffi_dnsquestion_get_tag_array(dnsdist_ffi_dnsquestion_t* ref, const dnsdist_ffi_tag_t** out) __attribute__ ((visibility ("default")));
+
+void dnsdist_ffi_dnsquestion_set_result(dnsdist_ffi_dnsquestion_t* dq, const char* str, size_t strSize) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_set_rcode(dnsdist_ffi_dnsquestion_t* dq, int rcode) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_set_len(dnsdist_ffi_dnsquestion_t* dq, uint16_t len) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_set_skip_cache(dnsdist_ffi_dnsquestion_t* dq, bool skipCache) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_set_use_ecs(dnsdist_ffi_dnsquestion_t* dq, bool useECS) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_set_ecs_override(dnsdist_ffi_dnsquestion_t* dq, bool ecsOverride) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_set_ecs_prefix_length(dnsdist_ffi_dnsquestion_t* dq, uint16_t ecsPrefixLength) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_set_temp_failure_ttl(dnsdist_ffi_dnsquestion_t* dq, uint32_t tempFailureTTL) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_unset_temp_failure_ttl(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_set_tag(dnsdist_ffi_dnsquestion_t* dq, const char* label, const char* value) __attribute__ ((visibility ("default")));
+
+void dnsdist_ffi_dnsquestion_set_http_response(dnsdist_ffi_dnsquestion_t* dq, uint16_t statusCode, const char* body, const char* contentType) __attribute__ ((visibility ("default")));
+
+size_t dnsdist_ffi_dnsquestion_get_trailing_data(dnsdist_ffi_dnsquestion_t* dq, const char** out) __attribute__ ((visibility ("default")));
+
+bool dnsdist_ffi_dnsquestion_set_trailing_data(dnsdist_ffi_dnsquestion_t* dq, const char* data, size_t dataLen) __attribute__ ((visibility ("default")));
+
+void dnsdist_ffi_dnsquestion_send_trap(dnsdist_ffi_dnsquestion_t* dq, const char* reason, size_t reasonLen) __attribute__ ((visibility ("default")));
+
+typedef struct dnsdist_ffi_servers_list_t dnsdist_ffi_servers_list_t;
+typedef struct dnsdist_ffi_server_t dnsdist_ffi_server_t;
+
+size_t dnsdist_ffi_servers_list_get_count(const dnsdist_ffi_servers_list_t* list) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_servers_list_get_server(const dnsdist_ffi_servers_list_t* list, size_t idx, const dnsdist_ffi_server_t** out) __attribute__ ((visibility ("default")));
+size_t dnsdist_ffi_servers_list_chashed(const dnsdist_ffi_servers_list_t* list, const dnsdist_ffi_dnsquestion_t* dq, size_t hash) __attribute__ ((visibility ("default")));
+size_t dnsdist_ffi_servers_list_whashed(const dnsdist_ffi_servers_list_t* list, const dnsdist_ffi_dnsquestion_t* dq, size_t hash) __attribute__ ((visibility ("default")));
+
+uint64_t dnsdist_ffi_server_get_outstanding(const dnsdist_ffi_server_t* server) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_server_is_up(const dnsdist_ffi_server_t* server) __attribute__ ((visibility ("default")));
+const char* dnsdist_ffi_server_get_name(const dnsdist_ffi_server_t* server) __attribute__ ((visibility ("default")));
+const char* dnsdist_ffi_server_get_name_with_addr(const dnsdist_ffi_server_t* server) __attribute__ ((visibility ("default")));
+int dnsdist_ffi_server_get_weight(const dnsdist_ffi_server_t* server) __attribute__ ((visibility ("default")));
+int dnsdist_ffi_server_get_order(const dnsdist_ffi_server_t* server) __attribute__ ((visibility ("default")));
--- /dev/null
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "dnsdist-lua-ffi.hh"
+#include "dnsdist-ecs.hh"
+
+uint16_t dnsdist_ffi_dnsquestion_get_qtype(const dnsdist_ffi_dnsquestion_t* dq)
+{
+ return dq->dq->qtype;
+}
+
+uint16_t dnsdist_ffi_dnsquestion_get_qclass(const dnsdist_ffi_dnsquestion_t* dq)
+{
+ return dq->dq->qclass;
+}
+
+static void dnsdist_ffi_comboaddress_to_raw(const ComboAddress& ca, const void** addr, size_t* addrSize)
+{
+ if (ca.isIPv4()) {
+ *addr = &ca.sin4.sin_addr.s_addr;
+ *addrSize = sizeof(ca.sin4.sin_addr.s_addr);
+ }
+ else {
+ *addr = &ca.sin6.sin6_addr.s6_addr;
+ *addrSize = sizeof(ca.sin6.sin6_addr.s6_addr);
+ }
+}
+
+void dnsdist_ffi_dnsquestion_get_localaddr(const dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize)
+{
+ dnsdist_ffi_comboaddress_to_raw(*dq->dq->local, addr, addrSize);
+}
+
+void dnsdist_ffi_dnsquestion_get_remoteaddr(const dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize)
+{
+ dnsdist_ffi_comboaddress_to_raw(*dq->dq->remote, addr, addrSize);
+}
+
+void dnsdist_ffi_dnsquestion_get_masked_remoteaddr(dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize, uint8_t bits)
+{
+ dq->maskedRemote = Netmask(*dq->dq->remote, bits).getMaskedNetwork();
+ dnsdist_ffi_comboaddress_to_raw(dq->maskedRemote, addr, addrSize);
+}
+
+uint16_t dnsdist_ffi_dnsquestion_get_local_port(const dnsdist_ffi_dnsquestion_t* dq)
+{
+ return dq->dq->local->getPort();
+}
+
+uint16_t dnsdist_ffi_dnsquestion_get_remote_port(const dnsdist_ffi_dnsquestion_t* dq)
+{
+ return dq->dq->remote->getPort();
+}
+
+void dnsdist_ffi_dnsquestion_get_qname_raw(const dnsdist_ffi_dnsquestion_t* dq, const char** qname, size_t* qnameSize)
+{
+ const auto& storage = dq->dq->qname->getStorage();
+ *qname = storage.data();
+ *qnameSize = storage.size();
+}
+
+size_t dnsdist_ffi_dnsquestion_get_qname_hash(const dnsdist_ffi_dnsquestion_t* dq, size_t init)
+{
+ return dq->dq->qname->hash(init);
+}
+
+int dnsdist_ffi_dnsquestion_get_rcode(const dnsdist_ffi_dnsquestion_t* dq)
+{
+ return dq->dq->dh->rcode;
+}
+
+void* dnsdist_ffi_dnsquestion_get_header(const dnsdist_ffi_dnsquestion_t* dq)
+{
+ return dq->dq->dh;
+}
+
+uint16_t dnsdist_ffi_dnsquestion_get_len(const dnsdist_ffi_dnsquestion_t* dq)
+{
+ return dq->dq->len;
+}
+
+size_t dnsdist_ffi_dnsquestion_get_size(const dnsdist_ffi_dnsquestion_t* dq)
+{
+ return dq->dq->size;
+}
+
+uint8_t dnsdist_ffi_dnsquestion_get_opcode(const dnsdist_ffi_dnsquestion_t* dq)
+{
+ return dq->dq->dh->opcode;
+}
+
+bool dnsdist_ffi_dnsquestion_get_tcp(const dnsdist_ffi_dnsquestion_t* dq)
+{
+ return dq->dq->tcp;
+}
+
+bool dnsdist_ffi_dnsquestion_get_skip_cache(const dnsdist_ffi_dnsquestion_t* dq)
+{
+ return dq->dq->skipCache;
+}
+
+bool dnsdist_ffi_dnsquestion_get_use_ecs(const dnsdist_ffi_dnsquestion_t* dq)
+{
+ return dq->dq->useECS;
+}
+
+bool dnsdist_ffi_dnsquestion_get_add_xpf(const dnsdist_ffi_dnsquestion_t* dq)
+{
+ return dq->dq->addXPF;
+}
+
+bool dnsdist_ffi_dnsquestion_get_ecs_override(const dnsdist_ffi_dnsquestion_t* dq)
+{
+ return dq->dq->ecsOverride;
+}
+
+uint16_t dnsdist_ffi_dnsquestion_get_ecs_prefix_length(const dnsdist_ffi_dnsquestion_t* dq)
+{
+ return dq->dq->ecsPrefixLength;
+}
+
+bool dnsdist_ffi_dnsquestion_is_temp_failure_ttl_set(const dnsdist_ffi_dnsquestion_t* dq)
+{
+ return dq->dq->tempFailureTTL != boost::none;
+}
+
+uint32_t dnsdist_ffi_dnsquestion_get_temp_failure_ttl(const dnsdist_ffi_dnsquestion_t* dq)
+{
+ if (dq->dq->tempFailureTTL) {
+ return *dq->dq->tempFailureTTL;
+ }
+ return 0;
+}
+
+bool dnsdist_ffi_dnsquestion_get_do(const dnsdist_ffi_dnsquestion_t* dq)
+{
+ return getEDNSZ(*dq->dq) & EDNS_HEADER_FLAG_DO;
+}
+
+void dnsdist_ffi_dnsquestion_get_sni(const dnsdist_ffi_dnsquestion_t* dq, const char** sni, size_t* sniSize)
+{
+ *sniSize = dq->dq->sni.size();
+ *sni = dq->dq->sni.c_str();
+}
+
+const char* dnsdist_ffi_dnsquestion_get_tag(const dnsdist_ffi_dnsquestion_t* dq, const char* label)
+{
+ const char * result = nullptr;
+
+ if (dq->dq->qTag != nullptr) {
+ const auto it = dq->dq->qTag->find(label);
+ if (it != dq->dq->qTag->cend()) {
+ result = it->second.c_str();
+ }
+ }
+
+ return result;
+}
+
+const char* dnsdist_ffi_dnsquestion_get_http_path(dnsdist_ffi_dnsquestion_t* dq)
+{
+ if (!dq->httpPath) {
+ if (dq->dq->du == nullptr) {
+ return nullptr;
+ }
+#ifdef HAVE_DNS_OVER_HTTPS
+ dq->httpPath = dq->dq->du->getHTTPPath();
+#endif /* HAVE_DNS_OVER_HTTPS */
+ }
+ if (dq->httpPath) {
+ return dq->httpPath->c_str();
+ }
+ return nullptr;
+}
+
+const char* dnsdist_ffi_dnsquestion_get_http_query_string(dnsdist_ffi_dnsquestion_t* dq)
+{
+ if (!dq->httpQueryString) {
+ if (dq->dq->du == nullptr) {
+ return nullptr;
+ }
+#ifdef HAVE_DNS_OVER_HTTPS
+ dq->httpQueryString = dq->dq->du->getHTTPQueryString();
+#endif /* HAVE_DNS_OVER_HTTPS */
+ }
+ if (dq->httpQueryString) {
+ return dq->httpQueryString->c_str();
+ }
+ return nullptr;
+}
+
+const char* dnsdist_ffi_dnsquestion_get_http_host(dnsdist_ffi_dnsquestion_t* dq)
+{
+ if (!dq->httpHost) {
+ if (dq->dq->du == nullptr) {
+ return nullptr;
+ }
+#ifdef HAVE_DNS_OVER_HTTPS
+ dq->httpHost = dq->dq->du->getHTTPHost();
+#endif /* HAVE_DNS_OVER_HTTPS */
+ }
+ if (dq->httpHost) {
+ return dq->httpHost->c_str();
+ }
+ return nullptr;
+}
+
+const char* dnsdist_ffi_dnsquestion_get_http_scheme(dnsdist_ffi_dnsquestion_t* dq)
+{
+ if (!dq->httpScheme) {
+ if (dq->dq->du == nullptr) {
+ return nullptr;
+ }
+#ifdef HAVE_DNS_OVER_HTTPS
+ dq->httpScheme = dq->dq->du->getHTTPScheme();
+#endif /* HAVE_DNS_OVER_HTTPS */
+ }
+ if (dq->httpScheme) {
+ return dq->httpScheme->c_str();
+ }
+ return nullptr;
+}
+
+static void fill_edns_option(const EDNSOptionViewValue& value, dnsdist_ffi_ednsoption_t& option)
+{
+ option.len = value.size;
+ option.data = nullptr;
+
+ if (value.size > 0) {
+ option.data = value.content;
+ }
+}
+
+// returns the length of the resulting 'out' array. 'out' is not set if the length is 0
+size_t dnsdist_ffi_dnsquestion_get_edns_options(dnsdist_ffi_dnsquestion_t* dq, const dnsdist_ffi_ednsoption_t** out)
+{
+ if (dq->dq->ednsOptions == nullptr) {
+ parseEDNSOptions(*(dq->dq));
+ }
+
+ size_t totalCount = 0;
+ for (const auto& option : *dq->dq->ednsOptions) {
+ totalCount += option.second.values.size();
+ }
+
+ dq->ednsOptionsVect.clear();
+ dq->ednsOptionsVect.resize(totalCount);
+ size_t pos = 0;
+ for (const auto& option : *dq->dq->ednsOptions) {
+ for (const auto& entry : option.second.values) {
+ fill_edns_option(entry, dq->ednsOptionsVect.at(pos));
+ dq->ednsOptionsVect.at(pos).optionCode = option.first;
+ pos++;
+ }
+ }
+
+ if (totalCount > 0) {
+ *out = dq->ednsOptionsVect.data();
+ }
+
+ return totalCount;
+}
+
+size_t dnsdist_ffi_dnsquestion_get_http_headers(dnsdist_ffi_dnsquestion_t* dq, const dnsdist_ffi_http_header_t** out)
+{
+ if (dq->dq->du == nullptr) {
+ return 0;
+ }
+
+#ifdef HAVE_DNS_OVER_HTTPS
+ dq->httpHeaders = dq->dq->du->getHTTPHeaders();
+ dq->httpHeadersVect.clear();
+ dq->httpHeadersVect.resize(dq->httpHeaders.size());
+ size_t pos = 0;
+ for (const auto& header : dq->httpHeaders) {
+ dq->httpHeadersVect.at(pos).name = header.first.c_str();
+ dq->httpHeadersVect.at(pos).value = header.second.c_str();
+ ++pos;
+ }
+
+ if (!dq->httpHeadersVect.empty()) {
+ *out = dq->httpHeadersVect.data();
+ }
+
+ return dq->httpHeadersVect.size();
+#else
+ return 0;
+#endif
+}
+
+size_t dnsdist_ffi_dnsquestion_get_tag_array(dnsdist_ffi_dnsquestion_t* dq, const dnsdist_ffi_tag_t** out)
+{
+ if (dq->dq->qTag == nullptr || dq->dq->qTag->size() == 0) {
+ return 0;
+ }
+
+ dq->tagsVect.clear();
+ dq->tagsVect.resize(dq->dq->qTag->size());
+ size_t pos = 0;
+
+ for (const auto& tag : *dq->dq->qTag) {
+ auto& entry = dq->tagsVect.at(pos);
+ entry.name = tag.first.c_str();
+ entry.value = tag.second.c_str();
+ ++pos;
+ }
+
+
+ if (!dq->tagsVect.empty()) {
+ *out = dq->tagsVect.data();
+ }
+
+ return dq->tagsVect.size();
+}
+
+void dnsdist_ffi_dnsquestion_set_result(dnsdist_ffi_dnsquestion_t* dq, const char* str, size_t strSize)
+{
+ dq->result = std::string(str, strSize);
+}
+
+void dnsdist_ffi_dnsquestion_set_http_response(dnsdist_ffi_dnsquestion_t* dq, uint16_t statusCode, const char* body, const char* contentType)
+{
+ if (dq->dq->du == nullptr) {
+ return;
+ }
+
+#ifdef HAVE_DNS_OVER_HTTPS
+ dq->dq->du->setHTTPResponse(statusCode, body, contentType);
+ dq->dq->dh->qr = true;
+#endif
+}
+
+void dnsdist_ffi_dnsquestion_set_rcode(dnsdist_ffi_dnsquestion_t* dq, int rcode)
+{
+ dq->dq->dh->rcode = rcode;
+ dq->dq->dh->qr = true;
+}
+
+void dnsdist_ffi_dnsquestion_set_len(dnsdist_ffi_dnsquestion_t* dq, uint16_t len)
+{
+ dq->dq->len = len;
+}
+
+void dnsdist_ffi_dnsquestion_set_skip_cache(dnsdist_ffi_dnsquestion_t* dq, bool skipCache)
+{
+ dq->dq->skipCache = skipCache;
+}
+
+void dnsdist_ffi_dnsquestion_set_use_ecs(dnsdist_ffi_dnsquestion_t* dq, bool useECS)
+{
+ dq->dq->useECS = useECS;
+}
+
+void dnsdist_ffi_dnsquestion_set_ecs_override(dnsdist_ffi_dnsquestion_t* dq, bool ecsOverride)
+{
+ dq->dq->ecsOverride = ecsOverride;
+}
+
+void dnsdist_ffi_dnsquestion_set_ecs_prefix_length(dnsdist_ffi_dnsquestion_t* dq, uint16_t ecsPrefixLength)
+{
+ dq->dq->ecsPrefixLength = ecsPrefixLength;
+}
+
+void dnsdist_ffi_dnsquestion_set_temp_failure_ttl(dnsdist_ffi_dnsquestion_t* dq, uint32_t tempFailureTTL)
+{
+ dq->dq->tempFailureTTL = tempFailureTTL;
+}
+
+void dnsdist_ffi_dnsquestion_unset_temp_failure_ttl(dnsdist_ffi_dnsquestion_t* dq)
+{
+ dq->dq->tempFailureTTL = boost::none;
+}
+
+void dnsdist_ffi_dnsquestion_set_tag(dnsdist_ffi_dnsquestion_t* dq, const char* label, const char* value)
+{
+ if (!dq->dq->qTag) {
+ dq->dq->qTag = std::make_shared<QTag>();
+ }
+
+ dq->dq->qTag->insert({label, value});
+}
+
+size_t dnsdist_ffi_dnsquestion_get_trailing_data(dnsdist_ffi_dnsquestion_t* dq, const char** out)
+{
+ dq->trailingData = dq->dq->getTrailingData();
+ if (!dq->trailingData.empty()) {
+ *out = dq->trailingData.data();
+ }
+
+ return dq->trailingData.size();
+}
+
+bool dnsdist_ffi_dnsquestion_set_trailing_data(dnsdist_ffi_dnsquestion_t* dq, const char* data, size_t dataLen)
+{
+ return dq->dq->setTrailingData(std::string(data, dataLen));
+}
+
+void dnsdist_ffi_dnsquestion_send_trap(dnsdist_ffi_dnsquestion_t* dq, const char* reason, size_t reasonLen)
+{
+ if (g_snmpAgent && g_snmpTrapsEnabled) {
+ g_snmpAgent->sendDNSTrap(*dq->dq, std::string(reason, reasonLen));
+ }
+}
+
+size_t dnsdist_ffi_servers_list_get_count(const dnsdist_ffi_servers_list_t* list)
+{
+ return list->ffiServers.size();
+}
+
+void dnsdist_ffi_servers_list_get_server(const dnsdist_ffi_servers_list_t* list, size_t idx, const dnsdist_ffi_server_t** out)
+{
+ *out = &list->ffiServers.at(idx);
+}
+
+static size_t dnsdist_ffi_servers_get_index_from_server(const ServerPolicy::NumberedServerVector& servers, const std::shared_ptr<DownstreamState>& server)
+{
+ for (const auto& pair : servers) {
+ if (pair.second == server) {
+ return pair.first - 1;
+ }
+ }
+ throw std::runtime_error("Unable to find servers in server list");
+}
+
+size_t dnsdist_ffi_servers_list_chashed(const dnsdist_ffi_servers_list_t* list, const dnsdist_ffi_dnsquestion_t* dq, size_t hash)
+{
+ auto server = chashedFromHash(list->servers, hash);
+ return dnsdist_ffi_servers_get_index_from_server(list->servers, server);
+}
+
+size_t dnsdist_ffi_servers_list_whashed(const dnsdist_ffi_servers_list_t* list, const dnsdist_ffi_dnsquestion_t* dq, size_t hash)
+{
+ auto server = whashedFromHash(list->servers, hash);
+ return dnsdist_ffi_servers_get_index_from_server(list->servers, server);
+}
+
+uint64_t dnsdist_ffi_server_get_outstanding(const dnsdist_ffi_server_t* server)
+{
+ return server->server->outstanding;
+}
+
+int dnsdist_ffi_server_get_weight(const dnsdist_ffi_server_t* server)
+{
+ return server->server->weight;
+}
+
+int dnsdist_ffi_server_get_order(const dnsdist_ffi_server_t* server)
+{
+ return server->server->order;
+}
+
+bool dnsdist_ffi_server_is_up(const dnsdist_ffi_server_t* server)
+{
+ return server->server->isUp();
+}
+
+const char* dnsdist_ffi_server_get_name(const dnsdist_ffi_server_t* server)
+{
+ return server->server->getName().c_str();
+}
+
+const char* dnsdist_ffi_server_get_name_with_addr(const dnsdist_ffi_server_t* server)
+{
+ return server->server->getNameWithAddr().c_str();
+}
+
+const std::string& getLuaFFIWrappers()
+{
+ static const std::string interface =
+#include "dnsdist-lua-ffi-interface.inc"
+ ;
+ static const std::string code = R"FFICodeContent(
+ local ffi = require("ffi")
+ local C = ffi.C
+
+ ffi.cdef[[
+)FFICodeContent" + interface + R"FFICodeContent(
+ ]]
+
+)FFICodeContent";
+ return code;
+}
--- /dev/null
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+#include "dnsdist.hh"
+
+extern "C" {
+#include "dnsdist-lua-ffi-interface.h"
+}
+
+// dnsdist_ffi_dnsquestion_t is a lightuserdata
+template<>
+struct LuaContext::Pusher<dnsdist_ffi_dnsquestion_t*> {
+ static const int minSize = 1;
+ static const int maxSize = 1;
+
+ static PushedObject push(lua_State* state, dnsdist_ffi_dnsquestion_t* ptr) noexcept {
+ lua_pushlightuserdata(state, ptr);
+ return PushedObject{state, 1};
+ }
+};
+
+struct dnsdist_ffi_dnsquestion_t
+{
+ dnsdist_ffi_dnsquestion_t(DNSQuestion* dq_): dq(dq_)
+ {
+ }
+
+ DNSQuestion* dq{nullptr};
+ std::vector<dnsdist_ffi_ednsoption_t> ednsOptionsVect;
+ std::vector<dnsdist_ffi_http_header_t> httpHeadersVect;
+ std::vector<dnsdist_ffi_tag_t> tagsVect;
+ std::unordered_map<std::string, std::string> httpHeaders;
+ std::string trailingData;
+ ComboAddress maskedRemote;
+ boost::optional<std::string> result{boost::none};
+ boost::optional<std::string> httpPath{boost::none};
+ boost::optional<std::string> httpQueryString{boost::none};
+ boost::optional<std::string> httpHost{boost::none};
+ boost::optional<std::string> httpScheme{boost::none};
+};
+
+// dnsdist_ffi_server_t is a lightuserdata
+template<>
+struct LuaContext::Pusher<dnsdist_ffi_server_t*> {
+ static const int minSize = 1;
+ static const int maxSize = 1;
+
+ static PushedObject push(lua_State* state, dnsdist_ffi_server_t* ptr) noexcept {
+ lua_pushlightuserdata(state, ptr);
+ return PushedObject{state, 1};
+ }
+};
+
+struct dnsdist_ffi_server_t
+{
+ dnsdist_ffi_server_t(const std::shared_ptr<DownstreamState>& server_): server(server_)
+ {
+ }
+
+ const std::shared_ptr<DownstreamState>& server;
+};
+
+// dnsdist_ffi_servers_list_t is a lightuserdata
+template<>
+struct LuaContext::Pusher<dnsdist_ffi_servers_list_t*> {
+ static const int minSize = 1;
+ static const int maxSize = 1;
+
+ static PushedObject push(lua_State* state, dnsdist_ffi_servers_list_t* ptr) noexcept {
+ lua_pushlightuserdata(state, ptr);
+ return PushedObject{state, 1};
+ }
+};
+
+struct dnsdist_ffi_servers_list_t
+{
+ dnsdist_ffi_servers_list_t(const ServerPolicy::NumberedServerVector& servers_): servers(servers_)
+ {
+ ffiServers.reserve(servers.size());
+ for (const auto& server: servers) {
+ ffiServers.push_back(dnsdist_ffi_server_t(server.second));
+ }
+ }
+
+ std::vector<dnsdist_ffi_server_t> ffiServers;
+ const ServerPolicy::NumberedServerVector& servers;
+};
+
+const std::string& getLuaFFIWrappers();
#include "dnsdist.hh"
#include "dnsdist-ecs.hh"
#include "dnsdist-kvs.hh"
+#include "dnsdist-lua-ffi.hh"
+#include "dolog.hh"
#include "dnsparser.hh"
class MaxQPSIPRule : public DNSRule
std::shared_ptr<KeyValueStore> d_kvs;
std::shared_ptr<KeyValueLookupKey> d_key;
};
+
+class LuaRule : public DNSRule
+{
+public:
+ typedef std::function<bool(const DNSQuestion* dq)> func_t;
+ LuaRule(const func_t& func): d_func(func)
+ {}
+
+ bool matches(const DNSQuestion* dq) const override
+ {
+ try {
+ std::lock_guard<std::mutex> lock(g_luamutex);
+ return d_func(dq);
+ } catch (const std::exception &e) {
+ warnlog("LuaRule failed inside Lua: %s", e.what());
+ } catch (...) {
+ warnlog("LuaRule failed inside Lua: [unknown exception]");
+ }
+ return false;
+ }
+
+ string toString() const override
+ {
+ return "Lua script";
+ }
+private:
+ func_t d_func;
+};
+
+class LuaFFIRule : public DNSRule
+{
+public:
+ typedef std::function<bool(dnsdist_ffi_dnsquestion_t* dq)> func_t;
+ LuaFFIRule(const func_t& func): d_func(func)
+ {}
+
+ bool matches(const DNSQuestion* dq) const override
+ {
+ dnsdist_ffi_dnsquestion_t dqffi(const_cast<DNSQuestion*>(dq));
+ try {
+ std::lock_guard<std::mutex> lock(g_luamutex);
+ return d_func(&dqffi);
+ } catch (const std::exception &e) {
+ warnlog("LuaFFIRule failed inside Lua: %s", e.what());
+ } catch (...) {
+ warnlog("LuaFFIRule failed inside Lua: [unknown exception]");
+ }
+ return false;
+ }
+
+ string toString() const override
+ {
+ return "Lua FFI script";
+ }
+private:
+ func_t d_func;
+};
Most of the query processing is done in C++ for maximum performance, but some operations are executed in Lua for maximum flexibility:
* Rules added by :func:`addLuaAction`
- * Server selection policies defined via :func:`setServerPolicyLua` or :func:`newServerPolicy`
+ * Server selection policies defined via :func:`setServerPolicyLua`, :func:`setServerPolicyLuaFFI` or :func:`newServerPolicy`
While Lua is fast, its use should be restricted to the strict necessary in order to achieve maximum performance, it might be worth considering using LuaJIT instead of Lua.
When Lua inspection is needed, the best course of action is to restrict the queries sent to Lua inspection by using :func:`addLuaAction` with a selector.
:param servers: A list of :class:`Server` objects
:param DNSQuestion dq: The incoming query
+ .. attribute:: ServerPolicy.ffipolicy
+
+ .. versionadded: 1.5.0
+
+ For policies implemented using the Lua FFI interface, the policy function itself.
+
+ .. attribute:: ServerPolicy.isFFI
+
+ .. versionadded: 1.5.0
+
+ Whether a Lua-based policy is implemented using the FFI interface.
+
+ .. attribute:: ServerPolicy.isLua
+
+ Whether this policy is a native (C++) policy or a Lua-based one.
+
+ .. attribute:: ServerPolicy.name
+
+ The name of the policy.
+
+ .. attribute:: ServerPolicy.policy
+
+ The policy function itself, except for FFI policies.
+
+ .. method:: Server:toString()
+
+ Return a textual representation of the policy.
+
+
Functions
---------
.. function:: setServerPolicyLua(name, function)
- Set server selection policy to one named `name`` and provided by ``function``.
+ Set server selection policy to one named ``name`` and provided by ``function``.
:param string name: name for this policy
:param string function: name of the function
+.. function:: setServerPolicyLuaFFI(name, function)
+
+ .. versionadded:: 1.5.0
+
+ Set server selection policy to one named ``name`` and provided by the FFI function ``function``.
+
+ :param string name: name for this policy
+ :param string function: name of the FFI function
+
.. function:: setServFailWhenNoServer(value)
If set, return a ServFail when no servers are available, instead of the default behaviour of dropping the query.
* :class:`Server`: generated with :func:`newServer`, represents a downstream server
* :class:`ComboAddress`: represents an IP address and port
* :class:`DNSName`: represents a domain name
+* :class:`Netmask`: represents a netmask
* :class:`NetmaskGroup`: represents a group of netmasks
* :class:`QPSLimiter`: implements a QPS-based filter
* :class:`SuffixMatchNode`: represents a group of domain suffixes for rapid testing of membership
--- /dev/null
+Netmask
+=======
+
+.. function:: newNetmask(str) -> Netmask
+ newNetmask(ca, bits) -> Netmask
+
+ .. versionadded:: 1.5.0
+
+ Returns a Netmask
+
+ :param string str: A netmask, like ``192.0.2.0/24``.
+ :param ComboAddress ca: A :class:`ComboAddress`.
+ :param int bits: The number of bits in this netmask.
+
+.. class:: Netmask
+
+ .. versionadded:: 1.5.0
+
+ Represents a netmask.
+
+ .. method:: Netmask:getBits() -> int
+
+ Return the number of bits of this netmask, for example ``24`` for ``192.0.2.0/24``.
+
+ .. method:: Netmask:getMaskedNetwork() -> ComboAddress
+
+ Return a :class:`ComboAddress` object representing the base network of this netmask object after masking any additional bits if necessary (for example ``192.0.2.0`` if the netmask was constructed with ``newNetmask('192.0.2.1/24')).
+
+ .. method:: Netmask:empty() -> bool
+
+ Return true if the netmask is empty, meaning that the netmask has not been set to a proper value.
+
+ .. method:: Netmask:isIPv4() -> bool
+
+ Return true if the netmask is an IPv4 one.
+
+ .. method:: Netmask:isIPv6() -> bool
+
+ Return true if the netmask is an IPv6 one.
+
+ .. method:: Netmask:getNetwork() -> ComboAddress
+
+ Return a :class:`ComboAddress` object representing the base network of this netmask object.
+
+ .. method:: Netmask:match(str) -> bool
+
+ Return true if the address passed in the ``str`` parameter belongs to this netmask.
+
+ :param string str: A network address, like ``192.0.2.0``.
+
+ .. method:: Netmask:toString() -> string
+
+ Return a string representation of the netmask, for example ``192.0.2.0/24``.
::
- function luarule(dq)
+ function luaaction(dq)
if(dq.qtype==DNSQType.NAPTR)
then
return DNSAction.Pool, "abuse" -- send to abuse pool
end
end
- addLuaAction(AllRule(), luarule)
+ addLuaAction(AllRule(), luaaction)
.. function:: addLuaResponseAction(DNSrule, function [, options])
:param KeyValueStore kvs: The key value store to query
:param KeyValueLookupKey lookupKey: The key to use for the lookup
+.. function:: LuaFFIRule(function)
+
+ .. versionadded:: 1.5.0
+
+ Invoke a Lua FFI function that accepts a pointer to a ``dnsdist_ffi_dnsquestion_t`` object, whose bindings are defined in ``dnsdist-lua-ffi.hh``.
+
+ The ``function`` should return true if the query matches, or false otherwise. If the Lua code fails, false is returned.
+
+ :param string function: the name of a Lua function
+
+.. function:: LuaRule(function)
+
+ .. versionadded:: 1.5.0
+
+ Invoke a Lua function that accepts a :class:`DNSQuestion` object.
+
+ The ``function`` should return true if the query matches, or false otherwise. If the Lua code fails, false is returned.
+
+ :param string function: the name of a Lua function
+
.. function:: MaxQPSIPRule(qps[, v4Mask[, v6Mask[, burst[, expiration[, cleanupDelay[, scanFraction]]]]]])
.. versionchanged:: 1.3.1
:param string function: the name of a Lua function
+.. function:: LuaFFIAction(function)
+
+ .. versionadded:: 1.5.0
+
+ Invoke a Lua FFI function that accepts a pointer to a ``dnsdist_ffi_dnsquestion_t`` object, whose bindings are defined in ``dnsdist-lua-ffi.hh``.
+
+ The ``function`` should return a :ref:`DNSAction`. If the Lua code fails, ServFail is returned.
+
+ :param string function: the name of a Lua function
+
+.. function:: LuaFFIResponseAction(function)
+
+ .. versionadded:: 1.5.0
+
+ Invoke a Lua FFI function that accepts a pointer to a ``dnsdist_ffi_dnsquestion_t`` object, whose bindings are defined in ``dnsdist-lua-ffi.hh``.
+
+ The ``function`` should return a :ref:`DNSResponseAction`. If the Lua code fails, ServFail is returned.
+
+ :param string function: the name of a Lua function
+
.. function:: LuaResponseAction(function)
Invoke a Lua function that accepts a :class:`DNSResponse`.
--- /dev/null
+
+#define BOOST_TEST_DYN_LINK
+#define BOOST_TEST_NO_MAIN
+
+#include <boost/test/unit_test.hpp>
+
+#include "dnsdist.hh"
+#include "dnsdist-lua-ffi.hh"
+#include "dolog.hh"
+
+uint16_t g_maxOutstanding{std::numeric_limits<uint16_t>::max()};
+
+std::mutex g_luamutex;
+#include "ext/luawrapper/include/LuaContext.hpp"
+LuaContext g_lua;
+
+bool g_snmpEnabled{false};
+bool g_snmpTrapsEnabled{false};
+DNSDistSNMPAgent* g_snmpAgent{nullptr};
+
+#if BENCH_POLICIES
+bool g_verbose{true};
+bool g_syslog{true};
+#include "dnsdist-rings.hh"
+Rings g_rings;
+GlobalStateHolder<NetmaskTree<DynBlock>> g_dynblockNMG;
+GlobalStateHolder<SuffixMatchTree<DynBlock>> g_dynblockSMT;
+#endif /* BENCH_POLICIES */
+
+/* add stub implementations, we don't want to include the corresponding object files
+ and their dependencies */
+
+#ifdef HAVE_DNS_OVER_HTTPS
+std::unordered_map<std::string, std::string> DOHUnit::getHTTPHeaders() const
+{
+ return {};
+}
+
+std::string DOHUnit::getHTTPPath() const
+{
+ return "";
+}
+
+std::string DOHUnit::getHTTPHost() const
+{
+ return "";
+}
+
+std::string DOHUnit::getHTTPScheme() const
+{
+ return "";
+}
+
+std::string DOHUnit::getHTTPQueryString() const
+{
+ return "";
+}
+
+void DOHUnit::setHTTPResponse(uint16_t statusCode, const std::string& body_, const std::string& contentType_)
+{
+}
+#endif /* HAVE_DNS_OVER_HTTPS */
+
+std::string DNSQuestion::getTrailingData() const
+{
+ return "";
+}
+
+bool DNSQuestion::setTrailingData(const std::string& tail)
+{
+ return false;
+}
+
+bool DNSDistSNMPAgent::sendDNSTrap(const DNSQuestion& dq, const std::string& reason)
+{
+ return false;
+}
+
+static DNSQuestion getDQ(const DNSName* providedName = nullptr)
+{
+ static const DNSName qname("powerdns.com.");
+ static const ComboAddress lc("127.0.0.1:53");
+ static const ComboAddress rem("192.0.2.1:42");
+ static struct timespec queryRealTime;
+ static struct dnsheader dh;
+
+ memset(&dh, 0, sizeof(dh));
+ uint16_t qtype = QType::A;
+ uint16_t qclass = QClass::IN;
+ size_t bufferSize = 0;
+ size_t queryLen = 0;
+ bool isTcp = false;
+ gettime(&queryRealTime, true);
+
+ DNSQuestion dq(providedName ? providedName : &qname, qtype, qclass, qname.wirelength(), &lc, &rem, &dh, bufferSize, queryLen, isTcp, &queryRealTime);
+ return dq;
+}
+
+static void benchPolicy(const ServerPolicy& pol)
+{
+#if BENCH_POLICIES
+ bool existingVerboseValue = g_verbose;
+ g_verbose = false;
+
+ std::vector<DNSName> names;
+ names.reserve(1000);
+ for (size_t idx = 0; idx < 1000; idx++) {
+ names.push_back(DNSName("powerdns-" + std::to_string(idx) + ".com."));
+ }
+ ServerPolicy::NumberedServerVector servers;
+ for (size_t idx = 1; idx <= 10; idx++) {
+ servers.push_back({ idx, std::make_shared<DownstreamState>(ComboAddress("192.0.2." + std::to_string(idx) + ":53")) });
+ servers.at(idx - 1).second->setUp();
+ /* we need to have a weight of at least 1000 to get an optimal repartition with the consistent hashing algo */
+ servers.at(idx - 1).second->setWeight(1000);
+ /* make sure that the hashes have been computed */
+ servers.at(idx - 1).second->hash();
+ }
+
+ StopWatch sw;
+ sw.start();
+ for (size_t idx = 0; idx < 1000; idx++) {
+ for (const auto& name : names) {
+ auto dq = getDQ(&name);
+ auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+ }
+ }
+ cerr<<pol.name<<" took "<<std::to_string(sw.udiff())<<" us for "<<names.size()<<endl;
+
+ g_verbose = existingVerboseValue;
+#endif /* BENCH_POLICIES */
+}
+
+static void resetLuaContext()
+{
+ /* we need to reset this before cleaning the Lua state because the server policy might holds
+ a reference to a Lua function (Lua policies) */
+ g_policy.setState(ServerPolicy("leastOutstanding", leastOutstanding, false));
+ g_lua = LuaContext();
+}
+
+BOOST_AUTO_TEST_SUITE(dnsdistlbpolicies)
+
+BOOST_AUTO_TEST_CASE(test_firstAvailable) {
+ auto dq = getDQ();
+
+ ServerPolicy pol{"firstAvailable", firstAvailable, false};
+ ServerPolicy::NumberedServerVector servers;
+ servers.push_back({ 1, std::make_shared<DownstreamState>(ComboAddress("192.0.2.1:53")) });
+
+ /* servers start as 'down' */
+ auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+ BOOST_CHECK(server == nullptr);
+
+ /* mark the server as 'up' */
+ servers.at(0).second->setUp();
+ server = getSelectedBackendFromPolicy(pol, servers, dq);
+ BOOST_CHECK(server != nullptr);
+
+ /* add a second server, we should still get the first one */
+ servers.push_back({ 2, std::make_shared<DownstreamState>(ComboAddress("192.0.2.2:53")) });
+ server = getSelectedBackendFromPolicy(pol, servers, dq);
+ BOOST_REQUIRE(server != nullptr);
+ BOOST_CHECK(server == servers.at(0).second);
+
+ /* mark the first server as 'down', second as 'up' */
+ servers.at(0).second->setDown();
+ servers.at(1).second->setUp();
+ server = getSelectedBackendFromPolicy(pol, servers, dq);
+ BOOST_REQUIRE(server != nullptr);
+ BOOST_CHECK(server == servers.at(1).second);
+
+ std::vector<DNSName> names;
+ names.reserve(1000);
+ for (size_t idx = 0; idx < 1000; idx++) {
+ names.push_back(DNSName("powerdns-" + std::to_string(idx) + ".com."));
+ }
+ std::map<std::shared_ptr<DownstreamState>, uint64_t> serversMap;
+ for (size_t idx = 1; idx <= 10; idx++) {
+ servers.push_back({ idx, std::make_shared<DownstreamState>(ComboAddress("192.0.2." + std::to_string(idx) + ":53")) });
+ serversMap[servers.at(idx - 1).second] = 0;
+ servers.at(idx - 1).second->setUp();
+ }
+
+ benchPolicy(pol);
+}
+
+BOOST_AUTO_TEST_CASE(test_leastOutstanding) {
+ auto dq = getDQ();
+
+ ServerPolicy pol{"leastOutstanding", leastOutstanding, false};
+ ServerPolicy::NumberedServerVector servers;
+ servers.push_back({ 1, std::make_shared<DownstreamState>(ComboAddress("192.0.2.1:53")) });
+
+ /* servers start as 'down' */
+ auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+ BOOST_CHECK(server == nullptr);
+
+ /* mark the server as 'up' */
+ servers.at(0).second->setUp();
+ server = getSelectedBackendFromPolicy(pol, servers, dq);
+ BOOST_CHECK(server != nullptr);
+
+ /* add a second server, we should still get the first one */
+ servers.push_back({ 2, std::make_shared<DownstreamState>(ComboAddress("192.0.2.2:53")) });
+ server = getSelectedBackendFromPolicy(pol, servers, dq);
+ BOOST_REQUIRE(server != nullptr);
+ BOOST_CHECK(server == servers.at(0).second);
+
+ /* mark the first server as 'down', second as 'up' */
+ servers.at(0).second->setDown();
+ servers.at(1).second->setUp();
+ server = getSelectedBackendFromPolicy(pol, servers, dq);
+ BOOST_REQUIRE(server != nullptr);
+ BOOST_CHECK(server == servers.at(1).second);
+
+ /* mark both servers as 'up', increase the outstanding count of the first one */
+ servers.at(0).second->setUp();
+ servers.at(0).second->outstanding = 42;
+ servers.at(1).second->setUp();
+ server = getSelectedBackendFromPolicy(pol, servers, dq);
+ BOOST_REQUIRE(server != nullptr);
+ BOOST_CHECK(server == servers.at(1).second);
+
+ benchPolicy(pol);
+}
+
+BOOST_AUTO_TEST_CASE(test_wrandom) {
+ auto dq = getDQ();
+
+ ServerPolicy pol{"wrandom", wrandom, false};
+ ServerPolicy::NumberedServerVector servers;
+ std::map<std::shared_ptr<DownstreamState>, uint64_t> serversMap;
+ for (size_t idx = 1; idx <= 10; idx++) {
+ servers.push_back({ idx, std::make_shared<DownstreamState>(ComboAddress("192.0.2." + std::to_string(idx) + ":53")) });
+ serversMap[servers.at(idx - 1).second] = 0;
+ servers.at(idx - 1).second->setUp();
+ }
+
+ benchPolicy(pol);
+
+ for (size_t idx = 0; idx < 1000; idx++) {
+ auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+ BOOST_REQUIRE(serversMap.count(server) == 1);
+ ++serversMap[server];
+ }
+ uint64_t total = 0;
+ for (const auto& entry : serversMap) {
+ BOOST_CHECK_GT(entry.second, 0);
+ BOOST_CHECK_GT(entry.second, (1000 / servers.size() / 2));
+ BOOST_CHECK_LT(entry.second, (1000 / servers.size() * 2));
+ total += entry.second;
+ }
+ BOOST_CHECK_EQUAL(total, 1000);
+
+ /* reset */
+ for (auto& entry : serversMap) {
+ entry.second = 0;
+ BOOST_CHECK_EQUAL(entry.first->weight, 1);
+ }
+
+ /* reset */
+ for (auto& entry : serversMap) {
+ entry.second = 0;
+ BOOST_CHECK_EQUAL(entry.first->weight, 1);
+ }
+ /* change the weight of the last server to 100, default is 1 */
+ servers.at(servers.size()-1).second->weight = 100;
+
+ for (size_t idx = 0; idx < 1000; idx++) {
+ auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+ BOOST_REQUIRE(serversMap.count(server) == 1);
+ ++serversMap[server];
+ }
+
+ total = 0;
+ uint64_t totalW = 0;
+ for (const auto& entry : serversMap) {
+ total += entry.second;
+ totalW += entry.first->weight;
+ }
+ BOOST_CHECK_EQUAL(total, 1000);
+ auto last = servers.at(servers.size()-1).second;
+ const auto got = serversMap[last];
+ float expected = (1000 * 1.0 * last->weight) / totalW;
+ BOOST_CHECK_GT(got, expected / 2);
+ BOOST_CHECK_LT(got, expected * 2);
+}
+
+BOOST_AUTO_TEST_CASE(test_whashed) {
+ std::vector<DNSName> names;
+ names.reserve(1000);
+ for (size_t idx = 0; idx < 1000; idx++) {
+ names.push_back(DNSName("powerdns-" + std::to_string(idx) + ".com."));
+ }
+
+ ServerPolicy pol{"whashed", whashed, false};
+ ServerPolicy::NumberedServerVector servers;
+ std::map<std::shared_ptr<DownstreamState>, uint64_t> serversMap;
+ for (size_t idx = 1; idx <= 10; idx++) {
+ servers.push_back({ idx, std::make_shared<DownstreamState>(ComboAddress("192.0.2." + std::to_string(idx) + ":53")) });
+ serversMap[servers.at(idx - 1).second] = 0;
+ servers.at(idx - 1).second->setUp();
+ }
+
+ benchPolicy(pol);
+
+ for (const auto& name : names) {
+ auto dq = getDQ(&name);
+ auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+ BOOST_REQUIRE(serversMap.count(server) == 1);
+ ++serversMap[server];
+ }
+
+ uint64_t total = 0;
+ for (const auto& entry : serversMap) {
+ BOOST_CHECK_GT(entry.second, 0);
+ BOOST_CHECK_GT(entry.second, (names.size() / servers.size() / 2));
+ BOOST_CHECK_LT(entry.second, (names.size() / servers.size() * 2));
+ total += entry.second;
+ }
+ BOOST_CHECK_EQUAL(total, names.size());
+
+ /* reset */
+ for (auto& entry : serversMap) {
+ entry.second = 0;
+ BOOST_CHECK_EQUAL(entry.first->weight, 1);
+ }
+
+ /* request 1000 times the same name, we should go to the same server every time */
+ {
+ auto dq = getDQ(&names.at(0));
+ auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+ for (size_t idx = 0; idx < 1000; idx++) {
+ BOOST_CHECK(getSelectedBackendFromPolicy(pol, servers, dq) == server);
+ }
+ }
+
+ /* reset */
+ for (auto& entry : serversMap) {
+ entry.second = 0;
+ BOOST_CHECK_EQUAL(entry.first->weight, 1);
+ }
+ /* change the weight of the last server to 100, default is 1 */
+ servers.at(servers.size()-1).second->setWeight(100);
+
+ for (const auto& name : names) {
+ auto dq = getDQ(&name);
+ auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+ BOOST_REQUIRE(serversMap.count(server) == 1);
+ ++serversMap[server];
+ }
+
+ total = 0;
+ uint64_t totalW = 0;
+ for (const auto& entry : serversMap) {
+ total += entry.second;
+ totalW += entry.first->weight;
+ }
+ BOOST_CHECK_EQUAL(total, names.size());
+ auto last = servers.at(servers.size()-1).second;
+ const auto got = serversMap[last];
+ float expected = (names.size() * 1.0 * last->weight) / totalW;
+ BOOST_CHECK_GT(got, expected / 2);
+ BOOST_CHECK_LT(got, expected * 2);
+}
+
+BOOST_AUTO_TEST_CASE(test_chashed) {
+ bool existingVerboseValue = g_verbose;
+ g_verbose = false;
+
+ std::vector<DNSName> names;
+ names.reserve(1000);
+ for (size_t idx = 0; idx < 1000; idx++) {
+ names.push_back(DNSName("powerdns-" + std::to_string(idx) + ".com."));
+ }
+
+ ServerPolicy pol{"chashed", chashed, false};
+ ServerPolicy::NumberedServerVector servers;
+ std::map<std::shared_ptr<DownstreamState>, uint64_t> serversMap;
+ for (size_t idx = 1; idx <= 10; idx++) {
+ servers.push_back({ idx, std::make_shared<DownstreamState>(ComboAddress("192.0.2." + std::to_string(idx) + ":53")) });
+ serversMap[servers.at(idx - 1).second] = 0;
+ servers.at(idx - 1).second->setUp();
+ /* we need to have a weight of at least 1000 to get an optimal repartition with the consistent hashing algo */
+ servers.at(idx - 1).second->setWeight(1000);
+ /* make sure that the hashes have been computed */
+ servers.at(idx - 1).second->hash();
+ }
+
+ benchPolicy(pol);
+
+ for (const auto& name : names) {
+ auto dq = getDQ(&name);
+ auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+ BOOST_REQUIRE(serversMap.count(server) == 1);
+ ++serversMap[server];
+ }
+
+ uint64_t total = 0;
+ for (const auto& entry : serversMap) {
+ BOOST_CHECK_GT(entry.second, 0);
+ BOOST_CHECK_GT(entry.second, (names.size() / servers.size() / 2));
+ BOOST_CHECK_LT(entry.second, (names.size() / servers.size() * 2));
+ total += entry.second;
+ }
+ BOOST_CHECK_EQUAL(total, names.size());
+
+ /* reset */
+ for (auto& entry : serversMap) {
+ entry.second = 0;
+ BOOST_CHECK_EQUAL(entry.first->weight, 1000);
+ }
+
+ /* request 1000 times the same name, we should go to the same server every time */
+ {
+ auto dq = getDQ(&names.at(0));
+ auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+ for (size_t idx = 0; idx < 1000; idx++) {
+ BOOST_CHECK(getSelectedBackendFromPolicy(pol, servers, dq) == server);
+ }
+ }
+
+ /* reset */
+ for (auto& entry : serversMap) {
+ entry.second = 0;
+ BOOST_CHECK_EQUAL(entry.first->weight, 1000);
+ }
+ /* change the weight of the last server to 100000, others stay at 1000 */
+ servers.at(servers.size()-1).second->setWeight(100000);
+
+ for (const auto& name : names) {
+ auto dq = getDQ(&name);
+ auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+ BOOST_REQUIRE(serversMap.count(server) == 1);
+ ++serversMap[server];
+ }
+
+ total = 0;
+ uint64_t totalW = 0;
+ for (const auto& entry : serversMap) {
+ total += entry.second;
+ totalW += entry.first->weight;
+ }
+ BOOST_CHECK_EQUAL(total, names.size());
+ auto last = servers.at(servers.size()-1).second;
+ const auto got = serversMap[last];
+ float expected = (names.size() * 1.0 * last->weight) / totalW;
+ BOOST_CHECK_GT(got, expected / 2);
+ BOOST_CHECK_LT(got, expected * 2);
+
+ g_verbose = existingVerboseValue;
+}
+
+BOOST_AUTO_TEST_CASE(test_lua) {
+ std::vector<DNSName> names;
+ names.reserve(1000);
+ for (size_t idx = 0; idx < 1000; idx++) {
+ names.push_back(DNSName("powerdns-" + std::to_string(idx) + ".com."));
+ }
+
+ static const std::string policySetupStr = R"foo(
+ local counter = 0
+ function luaroundrobin(servers, dq)
+ counter = counter + 1
+ return servers[1 + (counter % #servers)]
+ end
+
+ setServerPolicyLua("luaroundrobin", luaroundrobin)
+ )foo";
+ resetLuaContext();
+ g_lua.writeFunction("setServerPolicyLua", [](string name, ServerPolicy::policyfunc_t policy) {
+ g_policy.setState(ServerPolicy{name, policy, true});
+ });
+ g_lua.executeCode(policySetupStr);
+
+ ServerPolicy pol = g_policy.getCopy();
+ ServerPolicy::NumberedServerVector servers;
+ std::map<std::shared_ptr<DownstreamState>, uint64_t> serversMap;
+ for (size_t idx = 1; idx <= 10; idx++) {
+ servers.push_back({ idx, std::make_shared<DownstreamState>(ComboAddress("192.0.2." + std::to_string(idx) + ":53")) });
+ serversMap[servers.at(idx - 1).second] = 0;
+ servers.at(idx - 1).second->setUp();
+ }
+ BOOST_REQUIRE_EQUAL(servers.size(), 10);
+
+ for (const auto& name : names) {
+ auto dq = getDQ(&name);
+ auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+ BOOST_REQUIRE(serversMap.count(server) == 1);
+ ++serversMap[server];
+ }
+
+ uint64_t total = 0;
+ for (const auto& entry : serversMap) {
+ BOOST_CHECK_GT(entry.second, 0);
+ BOOST_CHECK_GT(entry.second, (names.size() / servers.size() / 2));
+ BOOST_CHECK_LT(entry.second, (names.size() / servers.size() * 2));
+ total += entry.second;
+ }
+ BOOST_CHECK_EQUAL(total, names.size());
+
+ benchPolicy(pol);
+}
+
+#ifdef LUAJIT_VERSION
+
+BOOST_AUTO_TEST_CASE(test_lua_ffi_rr) {
+ std::vector<DNSName> names;
+ names.reserve(1000);
+ for (size_t idx = 0; idx < 1000; idx++) {
+ names.push_back(DNSName("powerdns-" + std::to_string(idx) + ".com."));
+ }
+
+ static const std::string policySetupStr = R"foo(
+ local ffi = require("ffi")
+ local C = ffi.C
+ local counter = 0
+ function ffilb(servers_list, dq)
+ local serversCount = tonumber(C.dnsdist_ffi_servers_list_get_count(servers_list))
+ counter = counter + 1
+ return counter % serversCount
+ end
+
+ setServerPolicyLuaFFI("FFI round-robin", ffilb)
+ )foo";
+ resetLuaContext();
+ g_lua.executeCode(getLuaFFIWrappers());
+ g_lua.writeFunction("setServerPolicyLuaFFI", [](string name, ServerPolicy::ffipolicyfunc_t policy) {
+ g_policy.setState(ServerPolicy(name, policy));
+ });
+ g_lua.executeCode(policySetupStr);
+
+ ServerPolicy pol = g_policy.getCopy();
+ ServerPolicy::NumberedServerVector servers;
+ std::map<std::shared_ptr<DownstreamState>, uint64_t> serversMap;
+ for (size_t idx = 1; idx <= 10; idx++) {
+ servers.push_back({ idx, std::make_shared<DownstreamState>(ComboAddress("192.0.2." + std::to_string(idx) + ":53")) });
+ serversMap[servers.at(idx - 1).second] = 0;
+ servers.at(idx - 1).second->setUp();
+ }
+ BOOST_REQUIRE_EQUAL(servers.size(), 10);
+
+ for (const auto& name : names) {
+ auto dq = getDQ(&name);
+ auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+ BOOST_REQUIRE(serversMap.count(server) == 1);
+ ++serversMap[server];
+ }
+
+ uint64_t total = 0;
+ for (const auto& entry : serversMap) {
+ BOOST_CHECK_GT(entry.second, 0);
+ BOOST_CHECK_GT(entry.second, (names.size() / servers.size() / 2));
+ BOOST_CHECK_LT(entry.second, (names.size() / servers.size() * 2));
+ total += entry.second;
+ }
+ BOOST_CHECK_EQUAL(total, names.size());
+
+ benchPolicy(pol);
+}
+
+BOOST_AUTO_TEST_CASE(test_lua_ffi_hashed) {
+ std::vector<DNSName> names;
+ names.reserve(1000);
+ for (size_t idx = 0; idx < 1000; idx++) {
+ names.push_back(DNSName("powerdns-" + std::to_string(idx) + ".com."));
+ }
+
+ static const std::string policySetupStr = R"foo(
+ local ffi = require("ffi")
+ local C = ffi.C
+ function ffilb(servers_list, dq)
+ local serversCount = tonumber(C.dnsdist_ffi_servers_list_get_count(servers_list))
+ local hash = tonumber(C.dnsdist_ffi_dnsquestion_get_qname_hash(dq, 0))
+ return hash % serversCount
+ end
+
+ setServerPolicyLuaFFI("FFI hashed", ffilb)
+ )foo";
+ resetLuaContext();
+ g_lua.executeCode(getLuaFFIWrappers());
+ g_lua.writeFunction("setServerPolicyLuaFFI", [](string name, ServerPolicy::ffipolicyfunc_t policy) {
+ g_policy.setState(ServerPolicy(name, policy));
+ });
+ g_lua.executeCode(policySetupStr);
+
+ ServerPolicy pol = g_policy.getCopy();
+ ServerPolicy::NumberedServerVector servers;
+ std::map<std::shared_ptr<DownstreamState>, uint64_t> serversMap;
+ for (size_t idx = 1; idx <= 10; idx++) {
+ servers.push_back({ idx, std::make_shared<DownstreamState>(ComboAddress("192.0.2." + std::to_string(idx) + ":53")) });
+ serversMap[servers.at(idx - 1).second] = 0;
+ servers.at(idx - 1).second->setUp();
+ }
+ BOOST_REQUIRE_EQUAL(servers.size(), 10);
+
+ for (const auto& name : names) {
+ auto dq = getDQ(&name);
+ auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+ BOOST_REQUIRE(serversMap.count(server) == 1);
+ ++serversMap[server];
+ }
+
+ uint64_t total = 0;
+ for (const auto& entry : serversMap) {
+ BOOST_CHECK_GT(entry.second, 0);
+ BOOST_CHECK_GT(entry.second, (names.size() / servers.size() / 2));
+ BOOST_CHECK_LT(entry.second, (names.size() / servers.size() * 2));
+ total += entry.second;
+ }
+ BOOST_CHECK_EQUAL(total, names.size());
+
+ benchPolicy(pol);
+}
+
+BOOST_AUTO_TEST_CASE(test_lua_ffi_whashed) {
+ std::vector<DNSName> names;
+ names.reserve(1000);
+ for (size_t idx = 0; idx < 1000; idx++) {
+ names.push_back(DNSName("powerdns-" + std::to_string(idx) + ".com."));
+ }
+
+ static const std::string policySetupStr = R"foo(
+ local ffi = require("ffi")
+ local C = ffi.C
+ function ffilb(servers_list, dq)
+ return tonumber(C.dnsdist_ffi_servers_list_whashed(servers_list, dq, C.dnsdist_ffi_dnsquestion_get_qname_hash(dq, 0)))
+ end
+
+ setServerPolicyLuaFFI("FFI whashed", ffilb)
+ )foo";
+ resetLuaContext();
+ g_lua.executeCode(getLuaFFIWrappers());
+ g_lua.writeFunction("setServerPolicyLuaFFI", [](string name, ServerPolicy::ffipolicyfunc_t policy) {
+ g_policy.setState(ServerPolicy(name, policy));
+ });
+ g_lua.executeCode(policySetupStr);
+
+ ServerPolicy pol = g_policy.getCopy();
+ ServerPolicy::NumberedServerVector servers;
+ std::map<std::shared_ptr<DownstreamState>, uint64_t> serversMap;
+ for (size_t idx = 1; idx <= 10; idx++) {
+ servers.push_back({ idx, std::make_shared<DownstreamState>(ComboAddress("192.0.2." + std::to_string(idx) + ":53")) });
+ serversMap[servers.at(idx - 1).second] = 0;
+ servers.at(idx - 1).second->setUp();
+ }
+ BOOST_REQUIRE_EQUAL(servers.size(), 10);
+
+ for (const auto& name : names) {
+ auto dq = getDQ(&name);
+ auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+ BOOST_REQUIRE(serversMap.count(server) == 1);
+ ++serversMap[server];
+ }
+
+ uint64_t total = 0;
+ for (const auto& entry : serversMap) {
+ BOOST_CHECK_GT(entry.second, 0);
+ BOOST_CHECK_GT(entry.second, (names.size() / servers.size() / 2));
+ BOOST_CHECK_LT(entry.second, (names.size() / servers.size() * 2));
+ total += entry.second;
+ }
+ BOOST_CHECK_EQUAL(total, names.size());
+
+ benchPolicy(pol);
+}
+
+BOOST_AUTO_TEST_CASE(test_lua_ffi_chashed) {
+ bool existingVerboseValue = g_verbose;
+ g_verbose = false;
+
+ std::vector<DNSName> names;
+ names.reserve(1000);
+ for (size_t idx = 0; idx < 1000; idx++) {
+ names.push_back(DNSName("powerdns-" + std::to_string(idx) + ".com."));
+ }
+
+ static const std::string policySetupStr = R"foo(
+ local ffi = require("ffi")
+ local C = ffi.C
+ function ffilb(servers_list, dq)
+ return tonumber(C.dnsdist_ffi_servers_list_chashed(servers_list, dq, C.dnsdist_ffi_dnsquestion_get_qname_hash(dq, 0)))
+ end
+
+ setServerPolicyLuaFFI("FFI chashed", ffilb)
+ )foo";
+ resetLuaContext();
+ g_lua.executeCode(getLuaFFIWrappers());
+ g_lua.writeFunction("setServerPolicyLuaFFI", [](string name, ServerPolicy::ffipolicyfunc_t policy) {
+ g_policy.setState(ServerPolicy(name, policy));
+ });
+ g_lua.executeCode(policySetupStr);
+
+ ServerPolicy pol = g_policy.getCopy();
+ ServerPolicy::NumberedServerVector servers;
+ std::map<std::shared_ptr<DownstreamState>, uint64_t> serversMap;
+ for (size_t idx = 1; idx <= 10; idx++) {
+ servers.push_back({ idx, std::make_shared<DownstreamState>(ComboAddress("192.0.2." + std::to_string(idx) + ":53")) });
+ serversMap[servers.at(idx - 1).second] = 0;
+ servers.at(idx - 1).second->setUp();
+ /* we need to have a weight of at least 1000 to get an optimal repartition with the consistent hashing algo */
+ servers.at(idx - 1).second->setWeight(1000);
+ /* make sure that the hashes have been computed */
+ servers.at(idx - 1).second->hash();
+ }
+ BOOST_REQUIRE_EQUAL(servers.size(), 10);
+
+ for (const auto& name : names) {
+ auto dq = getDQ(&name);
+ auto server = getSelectedBackendFromPolicy(pol, servers, dq);
+ BOOST_REQUIRE(serversMap.count(server) == 1);
+ ++serversMap[server];
+ }
+
+ uint64_t total = 0;
+ for (const auto& entry : serversMap) {
+ BOOST_CHECK_GT(entry.second, 0);
+ BOOST_CHECK_GT(entry.second, (names.size() / servers.size() / 2));
+ BOOST_CHECK_LT(entry.second, (names.size() / servers.size() * 2));
+ total += entry.second;
+ }
+ BOOST_CHECK_EQUAL(total, names.size());
+
+ benchPolicy(pol);
+
+ g_verbose = existingVerboseValue;
+}
+
+#endif /* LUAJIT_VERSION */
+
+BOOST_AUTO_TEST_SUITE_END()
sender = getattr(self, method)
(_, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertEquals(receivedResponse, expectedResponse)
+
+class TestAdvancedLuaRule(DNSDistTest):
+
+ _config_template = """
+
+ function luarulefunction(dq)
+ if dq.qname:toString() == 'lua-rule.advanced.tests.powerdns.com.' then
+ return true
+ end
+ return false
+ end
+
+ addAction(LuaRule(luarulefunction), RCodeAction(DNSRCode.NOTIMP))
+ addAction(AllRule(), RCodeAction(DNSRCode.REFUSED))
+ -- newServer{address="127.0.0.1:%s"}
+ """
+
+ def testAdvancedLuaRule(self):
+ """
+ Advanced: Test the LuaRule rule
+ """
+ name = 'lua-rule.advanced.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ # dnsdist set RA = RD for spoofed responses
+ query.flags &= ~dns.flags.RD
+ notimplResponse = dns.message.make_response(query)
+ notimplResponse.set_rcode(dns.rcode.NOTIMP)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (_, receivedResponse) = sender(query, response=None, useQueue=False)
+ self.assertEquals(receivedResponse, notimplResponse)
+
+ name = 'not-lua-rule.advanced.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ # dnsdist set RA = RD for spoofed responses
+ query.flags &= ~dns.flags.RD
+ refusedResponse = dns.message.make_response(query)
+ refusedResponse.set_rcode(dns.rcode.REFUSED)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (_, receivedResponse) = sender(query, response=None, useQueue=False)
+ self.assertEquals(receivedResponse, refusedResponse)
+
+class TestAdvancedLuaFFI(DNSDistTest):
+
+ _config_template = """
+ local ffi = require("ffi")
+
+ local expectingUDP = true
+
+ function luaffirulefunction(dq)
+ local qtype = ffi.C.dnsdist_ffi_dnsquestion_get_qtype(dq)
+ if qtype ~= DNSQType.A and qtype ~= DNSQType.SOA then
+ print('invalid qtype')
+ return false
+ end
+
+ local qclass = ffi.C.dnsdist_ffi_dnsquestion_get_qclass(dq)
+ if qclass ~= DNSClass.IN then
+ print('invalid qclass')
+ return false
+ end
+
+ local ret_ptr = ffi.new("char *[1]")
+ local ret_ptr_param = ffi.cast("const char **", ret_ptr)
+ local ret_size = ffi.new("size_t[1]")
+ local ret_size_param = ffi.cast("size_t*", ret_size)
+ ffi.C.dnsdist_ffi_dnsquestion_get_qname_raw(dq, ret_ptr_param, ret_size_param)
+ if ret_size[0] ~= 36 then
+ print('invalid length for the qname ')
+ print(ret_size[0])
+ return false
+ end
+
+ local expectedQname = string.char(6)..'luaffi'..string.char(8)..'advanced'..string.char(5)..'tests'..string.char(8)..'powerdns'..string.char(3)..'com'
+ if ffi.string(ret_ptr[0]) ~= expectedQname then
+ print('invalid qname')
+ print(ffi.string(ret_ptr[0]))
+ return false
+ end
+
+ local rcode = ffi.C.dnsdist_ffi_dnsquestion_get_rcode(dq)
+ if rcode ~= 0 then
+ print('invalid rcode')
+ return false
+ end
+
+ local opcode = ffi.C.dnsdist_ffi_dnsquestion_get_opcode(dq)
+ if qtype == DNSQType.A and opcode ~= DNSOpcode.Query then
+ print('invalid opcode')
+ return false
+ elseif qtype == DNSQType.SOA and opcode ~= DNSOpcode.Update then
+ print('invalid opcode')
+ return false
+ end
+
+ local tcp = ffi.C.dnsdist_ffi_dnsquestion_get_tcp(dq)
+ if expectingUDP == tcp then
+ print('invalid tcp')
+ return false
+ end
+ expectingUDP = expectingUDP == false
+
+ local dnssecok = ffi.C.dnsdist_ffi_dnsquestion_get_do(dq)
+ if dnssecok ~= false then
+ print('invalid DNSSEC OK')
+ return false
+ end
+
+ local len = ffi.C.dnsdist_ffi_dnsquestion_get_len(dq)
+ if len ~= 52 then
+ print('invalid length')
+ print(len)
+ return false
+ end
+
+ local tag = ffi.C.dnsdist_ffi_dnsquestion_get_tag(dq, 'a-tag')
+ if ffi.string(tag) ~= 'a-value' then
+ print('invalid tag value')
+ print(ffi.string(tag))
+ return false
+ end
+ return true
+ end
+
+ function luaffiactionfunction(dq)
+ local qtype = ffi.C.dnsdist_ffi_dnsquestion_get_qtype(dq)
+ if qtype == DNSQType.A then
+ local str = "192.0.2.1"
+ local buf = ffi.new("char[?]", #str + 1)
+ ffi.copy(buf, str)
+ ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+ return DNSAction.Spoof
+ elseif qtype == DNSQType.SOA then
+ ffi.C.dnsdist_ffi_dnsquestion_set_rcode(dq, DNSRCode.REFUSED)
+ return DNSAction.Refused
+ end
+ end
+
+ function luaffiactionsettag(dq)
+ ffi.C.dnsdist_ffi_dnsquestion_set_tag(dq, 'a-tag', 'a-value')
+ return DNSAction.None
+ end
+
+ addAction(AllRule(), LuaFFIAction(luaffiactionsettag))
+ addAction(LuaFFIRule(luaffirulefunction), LuaFFIAction(luaffiactionfunction))
+ -- newServer{address="127.0.0.1:%s"}
+ """
+
+ def testAdvancedLuaFFI(self):
+ """
+ Advanced: Test the Lua FFI interface
+ """
+ name = 'luaffi.advanced.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ # dnsdist set RA = RD for spoofed responses
+ query.flags &= ~dns.flags.RD
+
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 60,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '192.0.2.1')
+ response.answer.append(rrset)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (_, receivedResponse) = sender(query, response=None, useQueue=False)
+ self.assertEquals(receivedResponse, response)
+
+ def testAdvancedLuaFFIUpdate(self):
+ """
+ Advanced: Test the Lua FFI interface via an update
+ """
+ name = 'luaffi.advanced.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'SOA', 'IN')
+ query.set_opcode(dns.opcode.UPDATE)
+ # dnsdist set RA = RD for spoofed responses
+ query.flags &= ~dns.flags.RD
+
+ response = dns.message.make_response(query)
+ response.set_rcode(dns.rcode.REFUSED)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (_, receivedResponse) = sender(query, response=None, useQueue=False)
+ self.assertEquals(receivedResponse, response)
import pycurl
from io import BytesIO
-#from hyper import HTTP20Connection
-#from hyper.ssl_compat import SSLContext, PROTOCOL_TLSv1_2
@unittest.skipIf('SKIP_DOH_TESTS' in os.environ, 'DNS over HTTPS tests are disabled')
class DNSDistDOHTest(DNSDistTest):
print("Launching tests..")
-# @classmethod
-# def openDOHConnection(cls, port, caFile, timeout=2.0):
-# sslctx = SSLContext(PROTOCOL_TLSv1_2)
-# sslctx.load_verify_locations(caFile)
-# return HTTP20Connection('127.0.0.1', port=port, secure=True, timeout=timeout, ssl_context=sslctx, force_proto='h2')
-
-# @classmethod
-# def sendDOHQueryOverConnection(cls, conn, baseurl, query, response=None, timeout=2.0):
-# url = cls.getDOHGetURL(baseurl, query)
-
-# if response:
-# cls._toResponderQueue.put(response, True, timeout)
-
-# conn.request('GET', url)
-
-# @classmethod
-# def recvDOHResponseOverConnection(cls, conn, useQueue=False, timeout=2.0):
-# message = None
-# data = conn.get_response()
-# if data:
-# data = data.read()
-# if data:
-# message = dns.message.from_wire(data)
-
-# if useQueue and not cls._fromResponderQueue.empty():
-# receivedQuery = cls._fromResponderQueue.get(True, timeout)
-# return (receivedQuery, message)
-# else:
-# return message
-
class TestDOH(DNSDistDOHTest):
_serverKey = 'server.key'
self.checkNoHeader('cache-control')
self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
self.assertEquals(response, receivedResponse)
+
+class TestDOHFFI(DNSDistDOHTest):
+
+ _serverKey = 'server.key'
+ _serverCert = 'server.chain'
+ _serverName = 'tls.tests.dnsdist.org'
+ _caCert = 'ca.pem'
+ _dohServerPort = 8443
+ _customResponseHeader1 = 'access-control-allow-origin: *'
+ _customResponseHeader2 = 'user-agent: derp'
+ _dohBaseURL = ("https://%s:%d/" % (_serverName, _dohServerPort))
+ _config_template = """
+ newServer{address="127.0.0.1:%s"}
+
+ addDOHLocal("127.0.0.1:%s", "%s", "%s", { "/" }, {customResponseHeaders={["access-control-allow-origin"]="*",["user-agent"]="derp",["UPPERCASE"]="VaLuE"}})
+
+ local ffi = require("ffi")
+
+ function dohHandler(dq)
+ local scheme = ffi.string(ffi.C.dnsdist_ffi_dnsquestion_get_http_scheme(dq))
+ local host = ffi.string(ffi.C.dnsdist_ffi_dnsquestion_get_http_host(dq))
+ local path = ffi.string(ffi.C.dnsdist_ffi_dnsquestion_get_http_path(dq))
+ local query_string = ffi.string(ffi.C.dnsdist_ffi_dnsquestion_get_http_query_string(dq))
+ if scheme == 'https' and host == '%s:%d' and path == '/' and query_string == '' then
+ local foundct = false
+ local headers_ptr = ffi.new("const dnsdist_ffi_http_header_t *[1]")
+ local headers_ptr_param = ffi.cast("const dnsdist_ffi_http_header_t **", headers_ptr)
+
+ local headers_count = tonumber(ffi.C.dnsdist_ffi_dnsquestion_get_http_headers(dq, headers_ptr_param))
+ if headers_count > 0 then
+ for idx = 0, headers_count-1 do
+ if ffi.string(headers_ptr[0][idx].name) == 'content-type' and ffi.string(headers_ptr[0][idx].value) == 'application/dns-message' then
+ foundct = true
+ break
+ end
+ end
+ end
+ if foundct then
+ ffi.C.dnsdist_ffi_dnsquestion_set_http_response(dq, 200, 'It works!', 'text/plain')
+ return DNSAction.HeaderModify
+ end
+ end
+ return DNSAction.None
+ end
+ addAction("http-lua-ffi.doh.tests.powerdns.com.", LuaFFIAction(dohHandler))
+ """
+ _config_params = ['_testServerPort', '_dohServerPort', '_serverCert', '_serverKey', '_serverName', '_dohServerPort']
+
+ def testHTTPLuaFFIResponse(self):
+ """
+ DOH: Lua FFI HTTP Response
+ """
+ name = 'http-lua-ffi.doh.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+ query.id = 0
+
+ (_, receivedResponse) = self.sendDOHPostQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, caFile=self._caCert, useQueue=False, rawResponse=True)
+ self.assertTrue(receivedResponse)
+ self.assertEquals(receivedResponse, b'It works!')
+ self.assertEquals(self._rcode, 200)
+ self.assertTrue('content-type: text/plain' in self._response_headers.decode())
+
receivedQuery.id = query.id
self.assertEquals(receivedQuery, query)
self.assertEquals(receivedResponse, response)
+
+class TestEDNSOptionsLuaFFI(DNSDistTest):
+
+ _config_template = """
+ local ffi = require("ffi")
+
+ function testEDNSOptions(dq)
+ local options_ptr = ffi.new("const dnsdist_ffi_ednsoption_t *[1]")
+ local ret_ptr_param = ffi.cast("const dnsdist_ffi_ednsoption_t **", options_ptr)
+
+ local options_count = tonumber(ffi.C.dnsdist_ffi_dnsquestion_get_edns_options(dq, ret_ptr_param))
+
+ local qname_ptr = ffi.new("const char *[1]")
+ local qname_ptr_param = ffi.cast("const char **", qname_ptr)
+ local qname_size = ffi.new("size_t[1]")
+ local qname_size_param = ffi.cast("size_t*", qname_size)
+ ffi.C.dnsdist_ffi_dnsquestion_get_qname_raw(dq, qname_ptr_param, qname_size_param)
+ local qname = ffi.string(qname_ptr[0])
+
+ if string.match(qname, 'noedns') then
+ if options_count ~= 0 then
+ local str = "192.0.2.255"
+ local buf = ffi.new("char[?]", #str + 1)
+ ffi.copy(buf, str)
+ ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+ return DNSAction.Spoof
+ end
+ end
+
+ local cookies_count = 0
+ local ecs_count = 0
+ local ecs_index = -1
+ local first_cookie_index = -1
+ local last_cookie_index = -1
+ if options_count > 0 then
+ for idx = 0, options_count-1 do
+ if options_ptr[0][idx].optionCode == EDNSOptionCode.COOKIE then
+ cookies_count = cookies_count + 1
+ if first_cookie_index == -1 then
+ first_cookie_index = idx
+ end
+ last_cookie_index = idx
+ elseif options_ptr[0][idx].optionCode == EDNSOptionCode.ECS then
+ ecs_count = ecs_count + 1
+ ecs_index = idx
+ end
+ end
+ end
+
+ if string.match(qname, 'multiplecookies') then
+ if cookies_count == 0 then
+ local str = "192.0.2.1"
+ local buf = ffi.new("char[?]", #str + 1)
+ ffi.copy(buf, str)
+ ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+ return DNSAction.Spoof
+ end
+ if cookies_count ~= 2 then
+ local str = "192.0.2.2"
+ local buf = ffi.new("char[?]", #str + 1)
+ ffi.copy(buf, str)
+ ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+ return DNSAction.Spoof
+ end
+ if options_ptr[0][first_cookie_index].len ~= 16 then
+ local str = "192.0.2.3"
+ local buf = ffi.new("char[?]", #str + 1)
+ ffi.copy(buf, str)
+ ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+ return DNSAction.Spoof
+ end
+ if options_ptr[0][last_cookie_index].len ~= 16 then
+ local str = "192.0.2.4"
+ local buf = ffi.new("char[?]", #str + 1)
+ ffi.copy(buf, str)
+ ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+ return DNSAction.Spoof
+ end
+ elseif string.match(qname, 'cookie') then
+
+ if cookies_count == 0 then
+ local str = "192.0.2.1"
+ local buf = ffi.new("char[?]", #str + 1)
+ ffi.copy(buf, str)
+ ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+ return DNSAction.Spoof
+ end
+ if cookies_count ~= 1 or options_ptr[0][first_cookie_index].len ~= 16 then
+ local str = "192.0.2.2"
+ local buf = ffi.new("char[?]", #str + 1)
+ ffi.copy(buf, str)
+ ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+ return DNSAction.Spoof
+ end
+ end
+
+ if string.match(qname, 'ecs4') then
+ if ecs_count == 0 then
+ local str = "192.0.2.51"
+ local buf = ffi.new("char[?]", #str + 1)
+ ffi.copy(buf, str)
+ ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+ return DNSAction.Spoof
+ end
+
+ if ecs_count ~= 1 or options_ptr[0][ecs_index].len ~= 8 then
+ local str = "192.0.2.52"
+ local buf = ffi.new("char[?]", #str + 1)
+ ffi.copy(buf, str)
+ ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+ return DNSAction.Spoof
+ end
+ end
+
+ if string.match(qname, 'ecs6') then
+ if ecs_count == 0 then
+ local str = "192.0.2.101"
+ local buf = ffi.new("char[?]", #str + 1)
+ ffi.copy(buf, str)
+ ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+ return DNSAction.Spoof
+ end
+ if ecs_count ~= 1 or options_ptr[0][ecs_index].len ~= 20 then
+ local str = "192.0.2.102"
+ local buf = ffi.new("char[?]", #str + 1)
+ ffi.copy(buf, str)
+ ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+ return DNSAction.Spoof
+ end
+ end
+
+ return DNSAction.None
+
+ end
+
+ addAction(AllRule(), LuaFFIAction(testEDNSOptions))
+
+ newServer{address="127.0.0.1:%s"}
+ """
+
+ def testWithoutEDNSFFI(self):
+ """
+ EDNS Options: No EDNS (FFI)
+ """
+ name = 'noedns.ednsoptions.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '192.0.2.255')
+ response.answer.append(rrset)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(receivedQuery, query)
+ self.assertEquals(receivedResponse, response)
+
+ def testCookieFFI(self):
+ """
+ EDNS Options: Cookie (FFI)
+ """
+ name = 'cookie.ednsoptions.tests.powerdns.com.'
+ eco = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco])
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(receivedQuery, query)
+ self.assertEquals(receivedResponse, response)
+
+ def testECS4FFI(self):
+ """
+ EDNS Options: ECS4 (FFI)
+ """
+ name = 'ecs4.ednsoptions.tests.powerdns.com.'
+ ecso = clientsubnetoption.ClientSubnetOption('1.2.3.4', 32)
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso])
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(receivedQuery, query)
+ self.assertEquals(receivedResponse, response)
+
+ def testECS6FFI(self):
+ """
+ EDNS Options: ECS6 (FFI)
+ """
+ name = 'ecs6.ednsoptions.tests.powerdns.com.'
+ ecso = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128)
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso])
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(receivedQuery, query)
+ self.assertEquals(receivedResponse, response)
+
+ def testECS6CookieFFI(self):
+ """
+ EDNS Options: Cookie + ECS6 (FFI)
+ """
+ name = 'cookie-ecs6.ednsoptions.tests.powerdns.com.'
+ eco = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
+ ecso = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128)
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso,eco])
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(receivedQuery, query)
+ self.assertEquals(receivedResponse, response)
+
+ def testMultiCookiesECS6FFI(self):
+ """
+ EDNS Options: Two Cookies + ECS6 (FFI)
+ """
+ name = 'multiplecookies-ecs6.ednsoptions.tests.powerdns.com.'
+ eco1 = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
+ ecso = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128)
+ eco2 = cookiesoption.CookiesOption(b'deadc0de', b'deadc0de')
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco1, ecso, eco2])
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(receivedQuery, query)
+ self.assertEquals(receivedResponse, response)