]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
implement DNSName class, test the hell out of it, make dnsdist use it
authorbert hubert <bert.hubert@netherlabs.nl>
Thu, 19 Feb 2015 08:47:30 +0000 (09:47 +0100)
committerbert hubert <bert.hubert@netherlabs.nl>
Thu, 19 Feb 2015 08:47:30 +0000 (09:47 +0100)
pdns/Makefile.am
pdns/dnsdist.cc
pdns/dnslabeltext.rl
pdns/dnsname.cc [new file with mode: 0644]
pdns/dnsname.hh [new file with mode: 0644]
pdns/dnswriter.hh
pdns/test-dnsname_cc.cc [new file with mode: 0644]

index 910898bfcb5fafe02e2812804266ebffe0f1c3b3..257afb0be5ed41d2e8789d69903f80fc97a44721 100644 (file)
@@ -562,12 +562,14 @@ dnsdist_SOURCES = \
        dns.cc \
        dnsdist.cc \
        dnslabeltext.cc \
+       dnsname.cc dnsname.hh \
        dnsparser.cc dnsparser.hh \
        dnsrecords.cc \
        dnswriter.cc dnswriter.hh \
        dolog.hh \
        iputils.cc \
        logger.cc \
+       lua-iputils.cc \
        misc.cc misc.hh \
        nsecrecords.cc \
        qtype.cc \
@@ -582,7 +584,7 @@ dnsdist_LDFLAGS = \
        $(BOOST_PROGRAM_OPTIONS_LDFLAGS)
 
 dnsdist_LDADD = \
-       $(POLARSSL_LIBS) \
+       $(POLARSSL_LIBS) -lreadline \
        $(BOOST_PROGRAM_OPTIONS_LIBS) $(LUA_LIBS)
 
 nsec3dig_SOURCES = \
@@ -953,6 +955,11 @@ testrunner_SOURCES = \
        ueberbackend.cc \
        zoneparser-tng.cc zoneparser-tng.hh
 
+if CXX2011
+testrunner_SOURCES += test-dnsname_cc.cc dnsname.cc dnsname.hh
+endif
+
+
 testrunner_LDFLAGS = \
        $(AM_LDFLAGS) \
        $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) \
index c84e53b9a8a950981a0926294937f6f3028ab662..ebd09c9cb8453484e85ed02eb94c2bc10a58cb6b 100644 (file)
 #include <atomic>
 #include "arguments.hh"
 #include "dolog.hh"
+#include <readline/readline.h>
+#include <readline/history.h>
+#include "dnsname.hh"
 #include <fstream>
+
 #undef L
 
 
@@ -100,8 +104,11 @@ struct IDState
   atomic<uint64_t> age;
 };
 
-struct DownstreamState 
+
+struct DownstreamState
 {
+  DownstreamState(const ComboAddress& remote_);
+
   int fd;            
   thread tid;
   ComboAddress remote;
@@ -111,13 +118,19 @@ struct DownstreamState
   atomic<uint64_t> outstanding{0};
   atomic<uint64_t> reuseds{0};
   atomic<uint64_t> queries{0};
+  struct {
+    atomic<uint64_t> idOffset{0};
+    atomic<uint64_t> sendErrors{0};
+    atomic<uint64_t> outstanding{0};
+    atomic<uint64_t> reuseds{0};
+    atomic<uint64_t> queries{0};
+  } prev;
 };
 
-DownstreamState* g_dstates;
-unsigned int g_numdownstreams;
+vector<std::shared_ptr<DownstreamState> > g_dstates;
 
 // listens on a dedicated socket, lobs answers from downstream servers to original requestors
-void* responderThread(DownstreamState* state)
+void* responderThread(std::shared_ptr<DownstreamState> state)
 {
   char packet[4096];
   
@@ -147,6 +160,19 @@ void* responderThread(DownstreamState* state)
   return 0;
 }
 
+DownstreamState::DownstreamState(const ComboAddress& remote_)
+{
+  remote = remote_;
+  
+  fd = SSocket(remote.sin4.sin_family, SOCK_DGRAM, 0);
+  SConnect(fd, remote);
+  
+  idStates.resize(g_maxOutstanding);
+  
+  warnlog("Added downstream server %s", remote.toStringWithPort());
+}
+
+
 struct ClientState
 {
   ComboAddress local;
@@ -154,39 +180,13 @@ struct ClientState
   int tcpFD;
 };
 
-#if 0
-DownstreamState& getBestDownstream()
-{
-  unsigned int lowest = std::numeric_limits<unsigned int>::max();
-  unsigned int chosen = 0;
-  for(unsigned int n = 0; n < g_numdownstreams; ++n) {
-    if(g_dstates[n].outstanding < lowest) {
-      chosen = n;
-      lowest=g_dstates[n].outstanding;
-    }
-  }      
-  return g_dstates[chosen];
-}
-#endif
 LuaContext g_lua;
-class Object {
-public:
-  Object() : value(10) {}
-
-  void increment() { std::cout << "incrementing" << std::endl; value++; } 
-
-  int value;
-};
-
 
-DownstreamState& getBestDownstream()
+DownstreamState& getBestDownstream(const ComboAddress& remote, const DNSName& qname, uint16_t qtype)
 {
-  auto pickServer=g_lua.readVariable<LuaContext::LuaFunctionCaller<int (void)> >("pickServer");
-  auto i = pickServer();
-  return g_dstates[i];
-
-
-
+  //auto pickServer=g_lua.readVariable<LuaContext::LuaFunctionCaller<std::shared_ptr<DownstreamState> (void)> >("pickServer");
+  auto pickServer=g_lua.readVariable<std::function<std::shared_ptr<DownstreamState>(ComboAddress, DNSName, uint16_t)> >("pickServer");
+  return *pickServer(remote, DNSName(qname), qtype);
 }
 
 static void daemonize(void)
@@ -227,20 +227,24 @@ try
   if(g_vm.count("regex-drop"))
     re=new Regex(g_vm["regex-drop"].as<string>());
 
+  auto blockFilter=g_lua.readVariable<std::function<bool(ComboAddress, DNSName, uint16_t)> >("blockFilter");
+
   for(;;) {
     len = recvfrom(cs->udpFD, packet, sizeof(packet), 0, (struct sockaddr*) &remote, &socklen);
     if(len < (int)sizeof(struct dnsheader)) 
       continue;
 
-    if(re) {
-      qname=questionExpand(packet, len, qtype); 
-      if(re->match(qname)) {
-       g_regexBlocks++;
-       continue;
-      }
+
+    DNSName qname(packet+12, len-12, &qtype);
+    if(blockFilter(remote, qname, qtype))
+      continue;
+    if(re && re->match(qname.toString())) {
+      g_regexBlocks++;
+      continue;
     }
+   
 
-    DownstreamState& ss = getBestDownstream();
+    DownstreamState& ss = getBestDownstream(remote, qname, qtype);
     ss.queries++;
 
     unsigned int idOffset = (ss.idOffset++) % g_maxOutstanding;
@@ -297,9 +301,9 @@ catch(...)
    Let's start naively.
 */
 
-int getTCPDownstream(DownstreamState** ds)
+int getTCPDownstream(DownstreamState** ds, const ComboAddress& remote, const std::string& qname, uint16_t qtype)
 {
-  *ds = &getBestDownstream();
+  *ds = &getBestDownstream(remote, qname, qtype);
   
   vinfolog("TCP connecting to downstream %s", (*ds)->remote.toStringWithPort());
   int sock = SSocket((*ds)->remote.sin4.sin_family, SOCK_STREAM, 0);
@@ -385,13 +389,16 @@ void* tcpClientThread(int pipefd)
   
   for(;;) {
     ConnectionInfo* citmp, ci;
+    ComboAddress fixme;
     readn2(pipefd, &citmp, sizeof(citmp));
     --g_tcpclientthreads.d_queued;
     ci=*citmp;
     delete citmp;
      
-    if(dsock == -1)
-      dsock = getTCPDownstream(&ds);
+    if(dsock == -1) {
+
+      dsock = getTCPDownstream(&ds, fixme, "", 0);
+    }
     else {
       vinfolog("Reusing existing TCP connection to %s", ds->remote.toStringWithPort());
     }
@@ -411,7 +418,7 @@ void* tcpClientThread(int pipefd)
         if(!putMsgLen(dsock, qlen)) {
          vinfolog("Downstream connection to %s died on us, getting a new one!", ds->remote.toStringWithPort());
           close(dsock);
-          dsock=getTCPDownstream(&ds);
+          dsock=getTCPDownstream(&ds, fixme, "", 0);
           goto retry;
         }
       
@@ -420,7 +427,7 @@ void* tcpClientThread(int pipefd)
         if(!getMsgLen(dsock, &rlen)) {
          vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds->remote.toStringWithPort());
           close(dsock);
-          dsock=getTCPDownstream(&ds);
+          dsock=getTCPDownstream(&ds, fixme, "", 0);
           goto retry;
         }
 
@@ -471,18 +478,13 @@ void* tcpAcceptorThread(void* p)
 }
 
 
-void* statThread()
+void* maintThread()
 {
   int interval = 1;
   if(!interval)
     return 0;
   uint32_t lastQueries=0;
 
-  uint64_t pqueries[g_numdownstreams];
-
-  for(unsigned int n=0; n < g_numdownstreams; ++n) 
-    pqueries[n] = g_dstates[n].queries.load();
-
   for(;;) {
     sleep(interval);
     
@@ -491,21 +493,19 @@ void* statThread()
 
     unsigned int outstanding=0;
     uint64_t numQueries=0;
-    for(unsigned int n=0; n < g_numdownstreams; ++n) {
-      DownstreamState& dss = g_dstates[n];
-
-      vinfolog(" %s: %d outstanding, %f qps", dss.remote.toStringWithPort(), dss.outstanding.load(), ((dss.queries.load() - pqueries[n])/interval));
+    for(auto& dss : g_dstates) {
+      vinfolog(" %s: %d outstanding, %f qps", dss->remote.toStringWithPort(), dss->outstanding.load(), ((dss->queries.load() - dss->prev.queries.load())/interval));
 
-      outstanding += dss.outstanding;
-      pqueries[n]=dss.queries.load();
-      numQueries += dss.queries;
-      for(unsigned int i=0 ; i < g_maxOutstanding; ++i) {
-        IDState& ids = dss.idStates[i];
+      outstanding += dss->outstanding;
+      dss->prev.queries.store(dss->queries.load());
+      numQueries += dss->queries;
+      
+      for(IDState& ids  : dss->idStates) {
         if(ids.origFD >=0 && ids.age++ > 2) {
           ids.age = 0;
           ids.origFD = -1;
-          dss.reuseds++;
-          --dss.outstanding;
+          dss->reuseds++;
+          --dss->outstanding;
         }          
       }
     }
@@ -517,42 +517,60 @@ void* statThread()
 }
 
 
-struct Server
-{
-  string d_name;
-  ComboAddress d_address;
-};
-vector<Server> g_servers;
-int defineServer(const std::string& name, const std::string& address)
-{
-  g_servers.push_back({name, ComboAddress(address, 53)});
-  return g_servers.size()-1;
-}
 
 void setupLua()
 {
-  g_lua.writeVariable("defineServer", &defineServer);
-  std::ifstream ifs("dnsdistconf.lua");
-  g_lua.executeCode(ifs);
+  g_lua.writeFunction("newServer", 
+                     [](const std::string& address)
+                     { 
+                       auto ret=std::shared_ptr<DownstreamState>(new DownstreamState(ComboAddress(address, 53)));
+                       ret->tid = move(thread(responderThread, ret));
+                       g_dstates.push_back(ret);
+                       return ret;
+                     } );
+
+  g_lua.writeFunction("deleteServer", 
+                     [](std::shared_ptr<DownstreamState> rem)
+                     { 
+                       g_dstates.erase(remove(g_dstates.begin(), g_dstates.end(), rem), g_dstates.end());
+                     } );
+
+
+  g_lua.writeFunction("listServers", []() {  
+      string ret;
+      for(auto& s : g_dstates) {
+       if(!ret.empty()) ret+="\n";
+       ret+=s->remote.toStringWithPort() + " " + std::to_string(s->queries.load()) + " " + std::to_string(s->outstanding.load());
+      }
+      return ret;
+    });
 
-  Object o1, o2;
-  g_lua.registerFunction("increment", &Object::increment);
 
-  g_lua.writeVariable("obj1", o1);
-  g_lua.writeVariable("obj2", o2);
-  g_lua.executeCode("obj1:increment();");
-  g_lua.executeCode("obj1:increment();");
+  g_lua.writeFunction("getServers", []() {
+      vector<pair<int, std::shared_ptr<DownstreamState> > > ret;
+      int count=1;
+      for(auto& s : g_dstates) {
+       ret.push_back(make_pair(count++, s));
+      }
+      return ret;
+    });
+
+  g_lua.registerFunction<string(DownstreamState::*)()>("tostring", [](const DownstreamState& s) { return s.remote.toStringWithPort(); });
+  
+  std::ifstream ifs("dnsdistconf.lua");
+  g_lua.registerFunction("tostring", &ComboAddress::toString);
 
-  std::cout << g_lua.readVariable<Object>("obj1").value << std::endl;
-  std::cout << g_lua.readVariable<Object>("obj2").value << std::endl;
+  g_lua.registerFunction("isPartOf", &DNSName::isPartOf);
+  g_lua.registerFunction("tostring", &DNSName::toString);
+  g_lua.writeFunction("newDNSName", [](const std::string& name) { return DNSName(name); });
 
+  g_lua.executeCode(ifs);
 }
 
+
 int main(int argc, char** argv)
 try
 {
-
-  setupLua();
   signal(SIGPIPE, SIG_IGN);
   openlog("dnsdist", LOG_PID, LOG_DAEMON);
   g_console=true;
@@ -584,13 +602,15 @@ try
   g_verbose=g_vm.count("verbose");
   g_maxOutstanding = g_vm["max-outstanding"].as<uint16_t>();
 
-  /*  
-  if(!g_vm.count("remotes")) {
-    cerr<<"Need to specify at least one remote address"<<endl;
-    cout<<desc<<endl;
-    exit(EXIT_FAILURE);
+  setupLua();
+
+  if(g_vm.count("remotes")) {
+    for(const auto& address : g_vm["remotes"].as<vector<string>>()) {
+      auto ret=std::shared_ptr<DownstreamState>(new DownstreamState(ComboAddress(address, 53)));
+      ret->tid = move(thread(responderThread, ret));
+      g_dstates.push_back(ret);
+    }
   }
-  */
 
   if(g_vm["daemon"].as<bool>())  {
     g_console=false;
@@ -600,26 +620,6 @@ try
     vinfolog("Running in the foreground");
   }
 
-  //  vector<string> remotes = g_vm["remotes"].as<vector<string> >();
-
-  g_numdownstreams = g_servers.size();
-  g_dstates = new DownstreamState[g_numdownstreams];
-  int pos=0;
-  for(const Server& server : g_servers) {
-    DownstreamState& dss = g_dstates[pos++];
-    dss.remote = server.d_address;
-
-    dss.fd = SSocket(dss.remote.sin4.sin_family, SOCK_DGRAM, 0);
-    SConnect(dss.fd, dss.remote);
-
-    dss.idStates.resize(g_maxOutstanding);
-
-    infolog("Added downstream server %s", dss.remote.toStringWithPort());
-
-    dss.tid = move(thread(responderThread, &dss));
-  }
-
   vector<string> locals;
   if(g_vm.count("local"))
     locals = g_vm["local"].as<vector<string> >();
@@ -627,7 +627,6 @@ try
     locals.push_back("::");
 
   for(const string& local : locals) {
-    cerr<<local<<endl;
     ClientState* cs = new ClientState;
     cs->local= ComboAddress(local, 53);
     cs->udpFD = SSocket(cs->local.sin4.sin_family, SOCK_DGRAM, 0);
@@ -662,15 +661,41 @@ try
     t1.detach();
   }
 
-  thread stattid(statThread);
-  stattid.join();
+  thread stattid(maintThread);
+  stattid.detach();
+
+  for(;;) {
+    char* sline = readline("> ");
+    if(!sline)
+      break;
+
+    string line(sline);
+    if(!line.empty())
+      add_history(sline);
+
+    free(sline);
+
+    
+    if(line=="quit")
+      break;
+
+    try {
+      g_lua.executeCode(line);
+    }
+    catch(std::exception& e) {
+      cerr<<"Error: "<<e.what()<<endl;
+    }
+    
+  }
+
+  // stattid.join();
 }
-/*
+
 catch(std::exception &e)
 {
   errlog("Fatal: %s", e.what());
 }
-*/
+
 catch(PDNSException &ae)
 {
   errlog("Fatal: %s", ae.reason);
index 58cdb9154d73f24831c5006ea61feeff20ef7067..5434af56ef208933d7fdd8b8f46c0362d469b9f8 100644 (file)
@@ -15,6 +15,7 @@ void appendSplit(vector<string>& ret, string& segment, char c)
   }
   segment.append(1, c);
 }
+
 }
 
 vector<string> segmentDNSText(const string& input )
@@ -80,6 +81,70 @@ vector<string> segmentDNSText(const string& input )
         return ret;
 };
 
+vector<string> segmentDNSName(const string& input )
+{
+%%{
+        machine dnsname;
+        write data;
+        alphtype unsigned char;
+}%%
+       (void)dnsname_error;  // silence warnings
+       (void)dnsname_en_main;
+        const char *p = input.c_str(), *pe = input.c_str() + input.length();
+        const char* eof = pe;
+        int cs;
+        char val = 0;
+
+        string label;
+        vector<string> ret;
+
+        %%{
+                action labelEnd { 
+                        ret.push_back(label);
+                        label.clear();
+                }
+                action labelBegin { 
+                        label.clear();
+                }
+
+                action reportEscaped {
+                  char c = *fpc;
+                  label.append(1, c);
+                }
+                action reportEscapedNumber {
+                  char c = *fpc;
+                  val *= 10;
+                  val += c-'0';
+                  
+                }
+                action doneEscapedNumber {
+                  label.append(1, val);
+                  val=0;
+                }
+                
+                action reportPlain {
+                  label.append(1, *(fpc));
+                }
+
+                escaped = '\\' (([^0-9]@reportEscaped) | ([0-9]{3}$reportEscapedNumber%doneEscapedNumber));
+                plain = (ascii-'\\'-'"'-'.') $ reportPlain;
+                labelElement = escaped | plain;            
+
+                main := ((labelElement+ '.') >labelBegin %labelEnd)+;
+
+                # Initialize and execute.
+                write init;
+                write exec;
+        }%%
+
+        if ( cs < dnsname_first_final ) {
+                throw runtime_error("Unable to parse DNS name '"+input+"'");
+        }
+
+        return ret;
+};
+
+
 #if 0
 int main()
 {
diff --git a/pdns/dnsname.cc b/pdns/dnsname.cc
new file mode 100644 (file)
index 0000000..19a32fa
--- /dev/null
@@ -0,0 +1,113 @@
+#include "dnsname.hh"
+#include <boost/format.hpp>
+#include <string>
+#include "dnswriter.hh"
+#include "misc.hh"
+
+DNSName::DNSName(const char* p)
+{
+  auto vec = segmentDNSName(p);
+  for(auto& e : vec)
+    d_labels.push_back(e);
+}
+
+DNSName::DNSName(const char* pos, int len, uint16_t* qtype)
+{
+  unsigned char labellen;
+  const char* end = pos + len;
+  while((labellen=*pos++) && pos < end) { // "scan and copy"
+    d_labels.push_back(string(pos, labellen));
+    pos+=labellen;
+  }
+  if(qtype && pos + labellen + 2 <= end)  
+    *qtype=(*(const unsigned char*)pos)*256 + *((const unsigned char*)pos+1);
+
+}
+
+std::string DNSName::toString() const
+{
+  if(d_labels.empty())  // I keep wondering if there is some deeper meaning to the need to do this
+    return ".";
+  std::string ret;
+  for(const auto& s : d_labels) {
+    ret+= escapeLabel(s) + ".";
+  }
+  return ret;
+}
+
+std::string DNSName::toDNSString() const
+{
+  std::string ret;
+  for(const auto& s : d_labels) {
+    ret.append(1, (char) s.length());
+    ret.append(s);
+  }
+  ret.append(1, (char)0);
+  return ret;
+}
+
+
+bool DNSName::isPartOf(const DNSName& parent) const
+{
+  auto us = d_labels.crbegin();
+  auto p = parent.d_labels.crbegin();
+  for(; us != d_labels.crend() && p != parent.d_labels.crend(); ++us, ++p) {
+    if(!pdns_iequals(*p, *us))
+      break;
+  }
+  return (p==parent.d_labels.crend());
+}
+
+void DNSName::appendRawLabel(const std::string& label)
+{
+  d_labels.push_back(label);
+}
+
+void DNSName::prependRawLabel(const std::string& label)
+{
+  d_labels.push_front(label);
+}
+
+deque<string> DNSName::getRawLabels() const
+{
+  return d_labels;
+}
+
+bool DNSName::chopOff() 
+{
+  if(d_labels.empty())
+    return false;
+  d_labels.pop_front();
+  return true;
+}
+
+bool DNSName::operator==(const DNSName& rhs) const
+{
+  if(rhs.d_labels.size() != d_labels.size())
+    return false;
+
+  auto us = d_labels.crbegin();
+  auto p = rhs.d_labels.crbegin();
+  for(; us != d_labels.crend() && p != rhs.d_labels.crend(); ++us, ++p) {
+    if(!pdns_iequals(*p, *us))
+      return false;
+  }
+  return true;
+}
+
+string DNSName::escapeLabel(const std::string& label)
+{
+  string ret;
+  for(uint8_t p : label) {
+    if(p=='.') 
+      ret+="\\.";
+    else if(p=='\\')
+      ret+="\\\\";
+    else if(p > 0x21 && p < 0x7e)
+      ret.append(1, (char)p);
+    else {
+      ret+="\\" + (boost::format("%03o") % (unsigned int)p).str();
+    }
+  }
+  return ret;
+}
diff --git a/pdns/dnsname.hh b/pdns/dnsname.hh
new file mode 100644 (file)
index 0000000..d6b1c44
--- /dev/null
@@ -0,0 +1,39 @@
+#pragma once
+#include <string>
+#include <deque>
+
+/* Quest in life: 
+     accept escaped ascii presentations of DNS names and store them "natively"
+     accept a DNS packet with an offset, and extract a DNS name from it
+     build up DNSNames with prepend and append of 'raw' unescaped labels
+
+   Be able to turn them into ASCII and "DNS name in a packet" again on request
+
+   Provide some common operators for comparison, detection of being part of another domain 
+
+   NOTE: For now, everything MUST be . terminated, otherwise it is an error
+*/
+
+// As a side note, we currently store the labels in a fancy deque<>, but we could go for native format storage easily
+
+class DNSName
+{
+public:
+  DNSName() {}                 //!< Constructs the root name
+  DNSName(const char* p);      //!< Constructs from a human formatted, escaped presentation
+  DNSName(const std::string& str) : DNSName(str.c_str()) {}   //!< Constructs from a human formatted, escaped presentation
+  DNSName(const char* p, int len, uint16_t* qtype); //!< Construct from a DNS Packet, taking the first question
+  
+  bool isPartOf(const DNSName& rhs) const;   //!< Are we part of the rhs name?
+  bool operator==(const DNSName& rhs) const; //!< DNS-native comparison (case insensitive)
+  std::string toString() const;              //!< Our human-friendly, escaped, representation
+  std::string toDNSString() const;           //!< Our representation in DNS native format
+  void appendRawLabel(const std::string& str); //!< Append this unescaped label
+  void prependRawLabel(const std::string& str); //!< Prepend this unescaped label
+  std::deque<std::string> getRawLabels() const; //!< Individual raw unescaped labels
+  bool chopOff();                               //!< Turn www.powerdns.com. into powerdns.com., returns false for .
+private:
+  std::deque<std::string> d_labels;
+  static std::string escapeLabel(const std::string& orig);
+  static std::string unescapeLabel(const std::string& orig);
+};
index 70bc50e8e848af917342cca4c4489b7a41af3715..a1ea9ddf526109df897fa51d658fd216c993dfdb 100644 (file)
@@ -132,4 +132,5 @@ private:
 typedef vector<pair<string::size_type, string::size_type> > labelparts_t;
 bool labeltokUnescape(labelparts_t& parts, const string& label);
 std::vector<string> segmentDNSText(const string& text); // from dnslabeltext.rl
+std::vector<string> segmentDNSName(const string& input ); // from dnslabeltext.rl
 #endif
diff --git a/pdns/test-dnsname_cc.cc b/pdns/test-dnsname_cc.cc
new file mode 100644 (file)
index 0000000..440bf61
--- /dev/null
@@ -0,0 +1,120 @@
+#define BOOST_TEST_DYN_LINK
+#define BOOST_TEST_NO_MAIN
+#include <boost/test/unit_test.hpp>
+#include <boost/assign/std/map.hpp>
+#include "dnsname.hh"
+#include "misc.hh"
+#include "dnswriter.hh"
+#include "dnsrecords.hh"
+using namespace boost;
+using std::string;
+
+BOOST_AUTO_TEST_SUITE(dnsname_cc)
+
+BOOST_AUTO_TEST_CASE(test_basic) {
+  string before("www.ds9a.nl.");
+  DNSName b(before);
+
+  BOOST_CHECK_EQUAL(b.getRawLabels().size(), 3);
+  string after(b.toString());
+  BOOST_CHECK_EQUAL(before, after);
+
+  DNSName wwwds9anl("www.ds9a.nl.");
+  DNSName nl("nl.");
+  BOOST_CHECK(wwwds9anl.isPartOf(nl));
+  BOOST_CHECK(wwwds9anl.isPartOf(wwwds9anl));
+
+  BOOST_CHECK(!nl.isPartOf(wwwds9anl));
+
+  BOOST_CHECK(wwwds9anl == wwwds9anl);
+
+  BOOST_CHECK(DNSName("wWw.ds9A.Nl.") == DNSName("www.ds9a.nl."));
+  BOOST_CHECK(DNSName("www.ds9a.nl.") == DNSName("www.ds9a.nl."));
+
+  BOOST_CHECK(DNSName("www.ds9a.nl.").toString() == "www.ds9a.nl.");
+
+  DNSName left("ds9a.nl.");
+  left.prependRawLabel("www");
+  BOOST_CHECK( left == DNSName("WwW.Ds9A.Nl."));
+
+  left.appendRawLabel("com");
+
+  BOOST_CHECK( left == DNSName("WwW.Ds9A.Nl.com."));
+  
+  DNSName root;
+  BOOST_CHECK(root.toString() == ".");
+
+  root.appendRawLabel("www");
+  root.appendRawLabel("powerdns.com");
+  root.appendRawLabel("com");
+
+  BOOST_CHECK_EQUAL(root.toString(), "www.powerdns\\.com.com.");
+
+  DNSName rfc4343_2_2(R"(Donald\032E\.\032Eastlake\0323rd.example.)");
+  DNSName example("example.");
+  BOOST_CHECK(rfc4343_2_2.isPartOf(example));
+
+  auto labels=rfc4343_2_2.getRawLabels();
+  BOOST_CHECK_EQUAL(*labels.begin(), "Donald E. Eastlake 3rd");
+  BOOST_CHECK_EQUAL(*labels.rbegin(), "example");
+  BOOST_CHECK_EQUAL(labels.size(), 2);
+
+  try {
+    DNSName broken("bert..hubert.");
+    BOOST_CHECK(0);
+  }catch(...){}
+
+  DNSName n;
+  n.appendRawLabel("powerdns.dnsmaster");
+  n.appendRawLabel("powerdns");
+  n.appendRawLabel("com");
+
+  BOOST_CHECK_EQUAL(n.toString(), "powerdns\\.dnsmaster.powerdns.com.");
+
+  BOOST_CHECK_EQUAL(DNSName().toString(), ".");
+
+  DNSName p;
+  string label("power");
+  label.append(1, (char)0);
+  label.append("dns");
+  p.appendRawLabel(label);
+  p.appendRawLabel("com");
+
+  BOOST_CHECK_EQUAL(p.toString(), "power\\000dns.com.");
+}
+
+
+BOOST_AUTO_TEST_CASE(test_dnsstrings) {
+  DNSName w("www.powerdns.com.");
+  BOOST_CHECK_EQUAL(w.toDNSString(), string("\003www\010powerdns\003com\000", 18));
+}
+
+BOOST_AUTO_TEST_CASE(test_chopping) {
+  DNSName w("www.powerdns.com.");
+  BOOST_CHECK_EQUAL(w.toString(), "www.powerdns.com.");
+  BOOST_CHECK(w.chopOff());
+  BOOST_CHECK_EQUAL(w.toString(), "powerdns.com.");
+  BOOST_CHECK(w.chopOff());
+  BOOST_CHECK_EQUAL(w.toString(), "com.");
+  BOOST_CHECK(w.chopOff());
+  BOOST_CHECK_EQUAL(w.toString(), ".");
+  BOOST_CHECK(!w.chopOff());
+  BOOST_CHECK(!w.chopOff());
+
+  w.prependRawLabel("net");
+  w.prependRawLabel("root-servers");
+  w.prependRawLabel("a");
+  BOOST_CHECK_EQUAL(w.toString(), "a.root-servers.net.");
+}
+
+BOOST_AUTO_TEST_CASE(test_packetParse) {
+  vector<unsigned char> packet;
+  DNSPacketWriter dpw(packet, "www.ds9a.nl.", QType::AAAA);
+
+  uint16_t qtype;
+  DNSName dn((char*)&packet[12], packet.size() - 12, &qtype);
+  BOOST_CHECK_EQUAL(dn.toString(), "www.ds9a.nl.");
+  BOOST_CHECK_EQUAL(qtype, QType::AAAA);
+}
+
+BOOST_AUTO_TEST_SUITE_END()