#include <atomic>
#include "arguments.hh"
#include "dolog.hh"
+#include <readline/readline.h>
+#include <readline/history.h>
+#include "dnsname.hh"
#include <fstream>
+
#undef L
atomic<uint64_t> age;
};
-struct DownstreamState
+
+struct DownstreamState
{
+ DownstreamState(const ComboAddress& remote_);
+
int fd;
thread tid;
ComboAddress remote;
atomic<uint64_t> outstanding{0};
atomic<uint64_t> reuseds{0};
atomic<uint64_t> queries{0};
+ struct {
+ atomic<uint64_t> idOffset{0};
+ atomic<uint64_t> sendErrors{0};
+ atomic<uint64_t> outstanding{0};
+ atomic<uint64_t> reuseds{0};
+ atomic<uint64_t> queries{0};
+ } prev;
};
-DownstreamState* g_dstates;
-unsigned int g_numdownstreams;
+vector<std::shared_ptr<DownstreamState> > g_dstates;
// listens on a dedicated socket, lobs answers from downstream servers to original requestors
-void* responderThread(DownstreamState* state)
+void* responderThread(std::shared_ptr<DownstreamState> state)
{
char packet[4096];
return 0;
}
+DownstreamState::DownstreamState(const ComboAddress& remote_)
+{
+ remote = remote_;
+
+ fd = SSocket(remote.sin4.sin_family, SOCK_DGRAM, 0);
+ SConnect(fd, remote);
+
+ idStates.resize(g_maxOutstanding);
+
+ warnlog("Added downstream server %s", remote.toStringWithPort());
+}
+
+
struct ClientState
{
ComboAddress local;
int tcpFD;
};
-#if 0
-DownstreamState& getBestDownstream()
-{
- unsigned int lowest = std::numeric_limits<unsigned int>::max();
- unsigned int chosen = 0;
- for(unsigned int n = 0; n < g_numdownstreams; ++n) {
- if(g_dstates[n].outstanding < lowest) {
- chosen = n;
- lowest=g_dstates[n].outstanding;
- }
- }
- return g_dstates[chosen];
-}
-#endif
LuaContext g_lua;
-class Object {
-public:
- Object() : value(10) {}
-
- void increment() { std::cout << "incrementing" << std::endl; value++; }
-
- int value;
-};
-
-DownstreamState& getBestDownstream()
+DownstreamState& getBestDownstream(const ComboAddress& remote, const DNSName& qname, uint16_t qtype)
{
- auto pickServer=g_lua.readVariable<LuaContext::LuaFunctionCaller<int (void)> >("pickServer");
- auto i = pickServer();
- return g_dstates[i];
-
-
-
+ //auto pickServer=g_lua.readVariable<LuaContext::LuaFunctionCaller<std::shared_ptr<DownstreamState> (void)> >("pickServer");
+ auto pickServer=g_lua.readVariable<std::function<std::shared_ptr<DownstreamState>(ComboAddress, DNSName, uint16_t)> >("pickServer");
+ return *pickServer(remote, DNSName(qname), qtype);
}
static void daemonize(void)
if(g_vm.count("regex-drop"))
re=new Regex(g_vm["regex-drop"].as<string>());
+ auto blockFilter=g_lua.readVariable<std::function<bool(ComboAddress, DNSName, uint16_t)> >("blockFilter");
+
for(;;) {
len = recvfrom(cs->udpFD, packet, sizeof(packet), 0, (struct sockaddr*) &remote, &socklen);
if(len < (int)sizeof(struct dnsheader))
continue;
- if(re) {
- qname=questionExpand(packet, len, qtype);
- if(re->match(qname)) {
- g_regexBlocks++;
- continue;
- }
+
+ DNSName qname(packet+12, len-12, &qtype);
+ if(blockFilter(remote, qname, qtype))
+ continue;
+ if(re && re->match(qname.toString())) {
+ g_regexBlocks++;
+ continue;
}
+
- DownstreamState& ss = getBestDownstream();
+ DownstreamState& ss = getBestDownstream(remote, qname, qtype);
ss.queries++;
unsigned int idOffset = (ss.idOffset++) % g_maxOutstanding;
Let's start naively.
*/
-int getTCPDownstream(DownstreamState** ds)
+int getTCPDownstream(DownstreamState** ds, const ComboAddress& remote, const std::string& qname, uint16_t qtype)
{
- *ds = &getBestDownstream();
+ *ds = &getBestDownstream(remote, qname, qtype);
vinfolog("TCP connecting to downstream %s", (*ds)->remote.toStringWithPort());
int sock = SSocket((*ds)->remote.sin4.sin_family, SOCK_STREAM, 0);
for(;;) {
ConnectionInfo* citmp, ci;
+ ComboAddress fixme;
readn2(pipefd, &citmp, sizeof(citmp));
--g_tcpclientthreads.d_queued;
ci=*citmp;
delete citmp;
- if(dsock == -1)
- dsock = getTCPDownstream(&ds);
+ if(dsock == -1) {
+
+ dsock = getTCPDownstream(&ds, fixme, "", 0);
+ }
else {
vinfolog("Reusing existing TCP connection to %s", ds->remote.toStringWithPort());
}
if(!putMsgLen(dsock, qlen)) {
vinfolog("Downstream connection to %s died on us, getting a new one!", ds->remote.toStringWithPort());
close(dsock);
- dsock=getTCPDownstream(&ds);
+ dsock=getTCPDownstream(&ds, fixme, "", 0);
goto retry;
}
if(!getMsgLen(dsock, &rlen)) {
vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds->remote.toStringWithPort());
close(dsock);
- dsock=getTCPDownstream(&ds);
+ dsock=getTCPDownstream(&ds, fixme, "", 0);
goto retry;
}
}
-void* statThread()
+void* maintThread()
{
int interval = 1;
if(!interval)
return 0;
uint32_t lastQueries=0;
- uint64_t pqueries[g_numdownstreams];
-
- for(unsigned int n=0; n < g_numdownstreams; ++n)
- pqueries[n] = g_dstates[n].queries.load();
-
for(;;) {
sleep(interval);
unsigned int outstanding=0;
uint64_t numQueries=0;
- for(unsigned int n=0; n < g_numdownstreams; ++n) {
- DownstreamState& dss = g_dstates[n];
-
- vinfolog(" %s: %d outstanding, %f qps", dss.remote.toStringWithPort(), dss.outstanding.load(), ((dss.queries.load() - pqueries[n])/interval));
+ for(auto& dss : g_dstates) {
+ vinfolog(" %s: %d outstanding, %f qps", dss->remote.toStringWithPort(), dss->outstanding.load(), ((dss->queries.load() - dss->prev.queries.load())/interval));
- outstanding += dss.outstanding;
- pqueries[n]=dss.queries.load();
- numQueries += dss.queries;
- for(unsigned int i=0 ; i < g_maxOutstanding; ++i) {
- IDState& ids = dss.idStates[i];
+ outstanding += dss->outstanding;
+ dss->prev.queries.store(dss->queries.load());
+ numQueries += dss->queries;
+
+ for(IDState& ids : dss->idStates) {
if(ids.origFD >=0 && ids.age++ > 2) {
ids.age = 0;
ids.origFD = -1;
- dss.reuseds++;
- --dss.outstanding;
+ dss->reuseds++;
+ --dss->outstanding;
}
}
}
}
-struct Server
-{
- string d_name;
- ComboAddress d_address;
-};
-vector<Server> g_servers;
-int defineServer(const std::string& name, const std::string& address)
-{
- g_servers.push_back({name, ComboAddress(address, 53)});
- return g_servers.size()-1;
-}
void setupLua()
{
- g_lua.writeVariable("defineServer", &defineServer);
- std::ifstream ifs("dnsdistconf.lua");
- g_lua.executeCode(ifs);
+ g_lua.writeFunction("newServer",
+ [](const std::string& address)
+ {
+ auto ret=std::shared_ptr<DownstreamState>(new DownstreamState(ComboAddress(address, 53)));
+ ret->tid = move(thread(responderThread, ret));
+ g_dstates.push_back(ret);
+ return ret;
+ } );
+
+ g_lua.writeFunction("deleteServer",
+ [](std::shared_ptr<DownstreamState> rem)
+ {
+ g_dstates.erase(remove(g_dstates.begin(), g_dstates.end(), rem), g_dstates.end());
+ } );
+
+
+ g_lua.writeFunction("listServers", []() {
+ string ret;
+ for(auto& s : g_dstates) {
+ if(!ret.empty()) ret+="\n";
+ ret+=s->remote.toStringWithPort() + " " + std::to_string(s->queries.load()) + " " + std::to_string(s->outstanding.load());
+ }
+ return ret;
+ });
- Object o1, o2;
- g_lua.registerFunction("increment", &Object::increment);
- g_lua.writeVariable("obj1", o1);
- g_lua.writeVariable("obj2", o2);
- g_lua.executeCode("obj1:increment();");
- g_lua.executeCode("obj1:increment();");
+ g_lua.writeFunction("getServers", []() {
+ vector<pair<int, std::shared_ptr<DownstreamState> > > ret;
+ int count=1;
+ for(auto& s : g_dstates) {
+ ret.push_back(make_pair(count++, s));
+ }
+ return ret;
+ });
+
+ g_lua.registerFunction<string(DownstreamState::*)()>("tostring", [](const DownstreamState& s) { return s.remote.toStringWithPort(); });
+
+ std::ifstream ifs("dnsdistconf.lua");
+ g_lua.registerFunction("tostring", &ComboAddress::toString);
- std::cout << g_lua.readVariable<Object>("obj1").value << std::endl;
- std::cout << g_lua.readVariable<Object>("obj2").value << std::endl;
+ g_lua.registerFunction("isPartOf", &DNSName::isPartOf);
+ g_lua.registerFunction("tostring", &DNSName::toString);
+ g_lua.writeFunction("newDNSName", [](const std::string& name) { return DNSName(name); });
+ g_lua.executeCode(ifs);
}
+
int main(int argc, char** argv)
try
{
-
- setupLua();
signal(SIGPIPE, SIG_IGN);
openlog("dnsdist", LOG_PID, LOG_DAEMON);
g_console=true;
g_verbose=g_vm.count("verbose");
g_maxOutstanding = g_vm["max-outstanding"].as<uint16_t>();
- /*
- if(!g_vm.count("remotes")) {
- cerr<<"Need to specify at least one remote address"<<endl;
- cout<<desc<<endl;
- exit(EXIT_FAILURE);
+ setupLua();
+
+ if(g_vm.count("remotes")) {
+ for(const auto& address : g_vm["remotes"].as<vector<string>>()) {
+ auto ret=std::shared_ptr<DownstreamState>(new DownstreamState(ComboAddress(address, 53)));
+ ret->tid = move(thread(responderThread, ret));
+ g_dstates.push_back(ret);
+ }
}
- */
if(g_vm["daemon"].as<bool>()) {
g_console=false;
vinfolog("Running in the foreground");
}
- // vector<string> remotes = g_vm["remotes"].as<vector<string> >();
-
- g_numdownstreams = g_servers.size();
- g_dstates = new DownstreamState[g_numdownstreams];
- int pos=0;
- for(const Server& server : g_servers) {
- DownstreamState& dss = g_dstates[pos++];
-
- dss.remote = server.d_address;
-
- dss.fd = SSocket(dss.remote.sin4.sin_family, SOCK_DGRAM, 0);
- SConnect(dss.fd, dss.remote);
-
- dss.idStates.resize(g_maxOutstanding);
-
- infolog("Added downstream server %s", dss.remote.toStringWithPort());
-
- dss.tid = move(thread(responderThread, &dss));
- }
-
vector<string> locals;
if(g_vm.count("local"))
locals = g_vm["local"].as<vector<string> >();
locals.push_back("::");
for(const string& local : locals) {
- cerr<<local<<endl;
ClientState* cs = new ClientState;
cs->local= ComboAddress(local, 53);
cs->udpFD = SSocket(cs->local.sin4.sin_family, SOCK_DGRAM, 0);
t1.detach();
}
- thread stattid(statThread);
- stattid.join();
+ thread stattid(maintThread);
+ stattid.detach();
+
+ for(;;) {
+ char* sline = readline("> ");
+ if(!sline)
+ break;
+
+ string line(sline);
+ if(!line.empty())
+ add_history(sline);
+
+ free(sline);
+
+
+ if(line=="quit")
+ break;
+
+ try {
+ g_lua.executeCode(line);
+ }
+ catch(std::exception& e) {
+ cerr<<"Error: "<<e.what()<<endl;
+ }
+
+ }
+
+ // stattid.join();
}
-/*
+
catch(std::exception &e)
{
errlog("Fatal: %s", e.what());
}
-*/
+
catch(PDNSException &ae)
{
errlog("Fatal: %s", ae.reason);
--- /dev/null
+#include "dnsname.hh"
+#include <boost/format.hpp>
+#include <string>
+#include "dnswriter.hh"
+#include "misc.hh"
+
+DNSName::DNSName(const char* p)
+{
+ auto vec = segmentDNSName(p);
+ for(auto& e : vec)
+ d_labels.push_back(e);
+}
+
+DNSName::DNSName(const char* pos, int len, uint16_t* qtype)
+{
+ unsigned char labellen;
+ const char* end = pos + len;
+ while((labellen=*pos++) && pos < end) { // "scan and copy"
+ d_labels.push_back(string(pos, labellen));
+ pos+=labellen;
+ }
+ if(qtype && pos + labellen + 2 <= end)
+ *qtype=(*(const unsigned char*)pos)*256 + *((const unsigned char*)pos+1);
+
+}
+
+std::string DNSName::toString() const
+{
+ if(d_labels.empty()) // I keep wondering if there is some deeper meaning to the need to do this
+ return ".";
+ std::string ret;
+ for(const auto& s : d_labels) {
+ ret+= escapeLabel(s) + ".";
+ }
+ return ret;
+}
+
+std::string DNSName::toDNSString() const
+{
+ std::string ret;
+ for(const auto& s : d_labels) {
+ ret.append(1, (char) s.length());
+ ret.append(s);
+ }
+ ret.append(1, (char)0);
+ return ret;
+}
+
+
+bool DNSName::isPartOf(const DNSName& parent) const
+{
+ auto us = d_labels.crbegin();
+ auto p = parent.d_labels.crbegin();
+ for(; us != d_labels.crend() && p != parent.d_labels.crend(); ++us, ++p) {
+ if(!pdns_iequals(*p, *us))
+ break;
+ }
+ return (p==parent.d_labels.crend());
+}
+
+void DNSName::appendRawLabel(const std::string& label)
+{
+ d_labels.push_back(label);
+}
+
+void DNSName::prependRawLabel(const std::string& label)
+{
+ d_labels.push_front(label);
+}
+
+deque<string> DNSName::getRawLabels() const
+{
+ return d_labels;
+}
+
+bool DNSName::chopOff()
+{
+ if(d_labels.empty())
+ return false;
+ d_labels.pop_front();
+ return true;
+}
+
+bool DNSName::operator==(const DNSName& rhs) const
+{
+ if(rhs.d_labels.size() != d_labels.size())
+ return false;
+
+ auto us = d_labels.crbegin();
+ auto p = rhs.d_labels.crbegin();
+ for(; us != d_labels.crend() && p != rhs.d_labels.crend(); ++us, ++p) {
+ if(!pdns_iequals(*p, *us))
+ return false;
+ }
+ return true;
+}
+
+string DNSName::escapeLabel(const std::string& label)
+{
+ string ret;
+ for(uint8_t p : label) {
+ if(p=='.')
+ ret+="\\.";
+ else if(p=='\\')
+ ret+="\\\\";
+ else if(p > 0x21 && p < 0x7e)
+ ret.append(1, (char)p);
+ else {
+ ret+="\\" + (boost::format("%03o") % (unsigned int)p).str();
+ }
+ }
+ return ret;
+}
--- /dev/null
+#define BOOST_TEST_DYN_LINK
+#define BOOST_TEST_NO_MAIN
+#include <boost/test/unit_test.hpp>
+#include <boost/assign/std/map.hpp>
+#include "dnsname.hh"
+#include "misc.hh"
+#include "dnswriter.hh"
+#include "dnsrecords.hh"
+using namespace boost;
+using std::string;
+
+BOOST_AUTO_TEST_SUITE(dnsname_cc)
+
+BOOST_AUTO_TEST_CASE(test_basic) {
+ string before("www.ds9a.nl.");
+ DNSName b(before);
+
+ BOOST_CHECK_EQUAL(b.getRawLabels().size(), 3);
+ string after(b.toString());
+ BOOST_CHECK_EQUAL(before, after);
+
+ DNSName wwwds9anl("www.ds9a.nl.");
+ DNSName nl("nl.");
+ BOOST_CHECK(wwwds9anl.isPartOf(nl));
+ BOOST_CHECK(wwwds9anl.isPartOf(wwwds9anl));
+
+ BOOST_CHECK(!nl.isPartOf(wwwds9anl));
+
+ BOOST_CHECK(wwwds9anl == wwwds9anl);
+
+ BOOST_CHECK(DNSName("wWw.ds9A.Nl.") == DNSName("www.ds9a.nl."));
+ BOOST_CHECK(DNSName("www.ds9a.nl.") == DNSName("www.ds9a.nl."));
+
+ BOOST_CHECK(DNSName("www.ds9a.nl.").toString() == "www.ds9a.nl.");
+
+ DNSName left("ds9a.nl.");
+ left.prependRawLabel("www");
+ BOOST_CHECK( left == DNSName("WwW.Ds9A.Nl."));
+
+ left.appendRawLabel("com");
+
+ BOOST_CHECK( left == DNSName("WwW.Ds9A.Nl.com."));
+
+ DNSName root;
+ BOOST_CHECK(root.toString() == ".");
+
+ root.appendRawLabel("www");
+ root.appendRawLabel("powerdns.com");
+ root.appendRawLabel("com");
+
+ BOOST_CHECK_EQUAL(root.toString(), "www.powerdns\\.com.com.");
+
+ DNSName rfc4343_2_2(R"(Donald\032E\.\032Eastlake\0323rd.example.)");
+ DNSName example("example.");
+ BOOST_CHECK(rfc4343_2_2.isPartOf(example));
+
+ auto labels=rfc4343_2_2.getRawLabels();
+ BOOST_CHECK_EQUAL(*labels.begin(), "Donald E. Eastlake 3rd");
+ BOOST_CHECK_EQUAL(*labels.rbegin(), "example");
+ BOOST_CHECK_EQUAL(labels.size(), 2);
+
+ try {
+ DNSName broken("bert..hubert.");
+ BOOST_CHECK(0);
+ }catch(...){}
+
+ DNSName n;
+ n.appendRawLabel("powerdns.dnsmaster");
+ n.appendRawLabel("powerdns");
+ n.appendRawLabel("com");
+
+ BOOST_CHECK_EQUAL(n.toString(), "powerdns\\.dnsmaster.powerdns.com.");
+
+ BOOST_CHECK_EQUAL(DNSName().toString(), ".");
+
+ DNSName p;
+ string label("power");
+ label.append(1, (char)0);
+ label.append("dns");
+ p.appendRawLabel(label);
+ p.appendRawLabel("com");
+
+ BOOST_CHECK_EQUAL(p.toString(), "power\\000dns.com.");
+}
+
+
+BOOST_AUTO_TEST_CASE(test_dnsstrings) {
+ DNSName w("www.powerdns.com.");
+ BOOST_CHECK_EQUAL(w.toDNSString(), string("\003www\010powerdns\003com\000", 18));
+}
+
+BOOST_AUTO_TEST_CASE(test_chopping) {
+ DNSName w("www.powerdns.com.");
+ BOOST_CHECK_EQUAL(w.toString(), "www.powerdns.com.");
+ BOOST_CHECK(w.chopOff());
+ BOOST_CHECK_EQUAL(w.toString(), "powerdns.com.");
+ BOOST_CHECK(w.chopOff());
+ BOOST_CHECK_EQUAL(w.toString(), "com.");
+ BOOST_CHECK(w.chopOff());
+ BOOST_CHECK_EQUAL(w.toString(), ".");
+ BOOST_CHECK(!w.chopOff());
+ BOOST_CHECK(!w.chopOff());
+
+ w.prependRawLabel("net");
+ w.prependRawLabel("root-servers");
+ w.prependRawLabel("a");
+ BOOST_CHECK_EQUAL(w.toString(), "a.root-servers.net.");
+}
+
+BOOST_AUTO_TEST_CASE(test_packetParse) {
+ vector<unsigned char> packet;
+ DNSPacketWriter dpw(packet, "www.ds9a.nl.", QType::AAAA);
+
+ uint16_t qtype;
+ DNSName dn((char*)&packet[12], packet.size() - 12, &qtype);
+ BOOST_CHECK_EQUAL(dn.toString(), "www.ds9a.nl.");
+ BOOST_CHECK_EQUAL(qtype, QType::AAAA);
+}
+
+BOOST_AUTO_TEST_SUITE_END()