]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Add EDNS Client Subnet to queries, unit tests and regression tests.
authorRemi Gacogne <rgacogne-github@coredump.fr>
Thu, 3 Dec 2015 08:28:22 +0000 (09:28 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 7 Dec 2015 08:23:59 +0000 (09:23 +0100)
The main idea is to be able to add EDNS Client Subnet information
to queries to that the downstream servers know who the real client
is. This is easy if the query already has EDNS, since we just have
to add an ECS option to the existing pseudo-RR OPT. Note that if the
query has an existing ECS option, we just leave it as it is unless
setECSOverride has been set to true.
If the original query has no ENDS support, we can easily add the
pseudo-RR OPT with an ECS option, but we need to remove EDNS from
the response, otherwise the client might choke on it. This is not too
hard it the pseudo-RR is the last one, but otherwise removing it
could cause issue with name compression (offsets), therefore we
completely rewrite the response in that case.

This implementation tries hard not to parse the query, and not
to allocate too much objects to parse the response. The huge downside
is that we do not reuse as much parsing code as we probably could.

Unit tests have been added, they are not enabled unless
--enable-unit-tests is set.

Regression tests have been added in regression-tests.dnsdist.

Unit and regression tests are enabled in travis.

29 files changed:
.travis.yml
pdns/README-dnsdist.md
pdns/dns.hh
pdns/dnsdist-ecs.cc [new file with mode: 0644]
pdns/dnsdist-ecs.hh [new file with mode: 0644]
pdns/dnsdist-lua.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/configure.ac
pdns/dnsdistdist/dnsdist-ecs.cc [new symlink]
pdns/dnsdistdist/dnsdist-ecs.hh [new symlink]
pdns/dnsdistdist/dnsparser.cc [new symlink]
pdns/dnsdistdist/ednssubnet.cc [new symlink]
pdns/dnsdistdist/ednssubnet.hh [new symlink]
pdns/dnsdistdist/m4/pdns_enable_unit_tests.m4 [new symlink]
pdns/dnsdistdist/test-dnsdist_cc.cc [new symlink]
pdns/dnsdistdist/testrunner.cc [new file with mode: 0644]
pdns/dnsrulactions.hh
pdns/test-dnsdist_cc.cc [new file with mode: 0644]
regression-tests.dnsdist/.gitignore [new file with mode: 0644]
regression-tests.dnsdist/clientsubnetoption.py [new file with mode: 0644]
regression-tests.dnsdist/dnsdist.conf [new file with mode: 0644]
regression-tests.dnsdist/dnsdisttests.py [new file with mode: 0644]
regression-tests.dnsdist/requirements.txt [new file with mode: 0644]
regression-tests.dnsdist/runtests [new file with mode: 0755]
regression-tests.dnsdist/test_Basics.py [new file with mode: 0644]
regression-tests.dnsdist/test_EdnsClientSubnet.py [new file with mode: 0644]

index 9839895aaed4a69baab437c91ed9a706383313f9..49cf57880bef1f1b47ca4c1093dd615a30117e33 100644 (file)
@@ -279,11 +279,15 @@ script:
   - cd pdns/dnsdistdist
   - tar xf dnsdist*.tar.bz2
   - cd dnsdist-*
-  - ./configure
+  - ./configure --enable-unit-tests
   - make -k -j3
+  - ./testrunner
+  - cp ./dnsdist ../../../regression-tests.dnsdist/
+  - cd ../../../regression-tests.dnsdist
+  - DNSDISTBIN=./dnsdist ./runtests
+  - rm -f ./dnsdist
   - cd ..
-  - rm -rf dnsdist-*/
-  - cd ../../
+  - rm -rf pdns/dnsdistdist/dnsdist-*/
 
 
   ### api ###
index b5366c397ab878b0f4196a2175a1b1cf185ab882..fefe3791c32a3a24a9c846e8c5e30a7afb2c1098 100644 (file)
@@ -145,6 +145,17 @@ its RCODE differs from NXDomain, ServFail and Refused.
 newServer {address="192.0.2.1", checkType="AAAA", checkName="a.root-servers.net.", mustResolve=true}
 ```
 
+In order to provide the downstream server with the address of the real client,
+or at least the one talking to dnsdist, the 'useClientSubnet' parameter can be used
+when declaring a new server. This parameter indicates whether an EDNS Client Subnet option
+should be added to the request. If the incoming request already contains an EDNS Client Subnet value,
+it will not be overriden unless setECSOverride is set to true. The source prefix-length may be
+configured with:
+```
+> setECSSourcePrefixV4(24)
+> setECSSourcePrefixV6(56)
+```
+
 TCP timeouts
 ------------
 
@@ -580,7 +591,7 @@ Here are all functions:
    * `errlog(string)`: log at level error
  * Server related:
    * `newServer("ip:port")`: instantiate a new downstream server with default settings
-   * `newServer({address="ip:port", qps=1000, order=1, weight=10, pool="abuse", retries=5, tcpSendTimeout=30, tcpRecvTimeout=30, checkName="a.root-servers.net.", checkType="A", mustResolve=false})`: 
+   * `newServer({address="ip:port", qps=1000, order=1, weight=10, pool="abuse", retries=5, tcpSendTimeout=30, tcpRecvTimeout=30, checkName="a.root-servers.net.", checkType="A", mustResolve=false, useClientSubnet=true})`:
 instantiate a server with additional parameters
    * `showServers()`: output all servers
    * `getServer(n)`: returns server with index n 
index ae0a2033dd9d6ec288359f1dc1152f56dc2e06a7..1cb12dcf35bafbba91dfbfb5473716034069793f 100644 (file)
@@ -215,11 +215,22 @@ struct dnsheader {
         unsigned        arcount :16;    /* number of resource entries */
 };
 
+static_assert(sizeof(dnsheader) == 12, "dnsheader size must be 12");
+
 inline uint16_t * getFlagsFromDNSHeader(struct dnsheader * dh)
 {
   return (uint16_t*) (((char *) dh) + sizeof(uint16_t));
 }
 
+#define DNS_TYPE_SIZE (2)
+#define DNS_CLASS_SIZE (2)
+#define DNS_TTL_SIZE (4)
+#define DNS_RDLENGTH_SIZE (2)
+#define EDNS_EXTENDED_RCODE_SIZE (1)
+#define EDNS_VERSION_SIZE (1)
+#define EDNS_OPTION_CODE_SIZE (2)
+#define EDNS_OPTION_LENGTH_SIZE (2)
+
 #if BYTE_ORDER == BIG_ENDIAN
 #define FLAGS_RD_OFFSET (8)
 #define FLAGS_CD_OFFSET (12)
diff --git a/pdns/dnsdist-ecs.cc b/pdns/dnsdist-ecs.cc
new file mode 100644 (file)
index 0000000..b0d5e2e
--- /dev/null
@@ -0,0 +1,411 @@
+
+#include "dolog.hh"
+#include "dnsdist.hh"
+#include "dnsdist-ecs.hh"
+#include "dnsparser.hh"
+#include "dnswriter.hh"
+#include "ednssubnet.hh"
+
+/* when we add EDNS to a query, we don't want to advertise
+   a large buffer size */
+size_t q_EdnsUDPPayloadSize = 512;
+/* draft-ietf-dnsop-edns-client-subnet-04 "11.1.  Privacy" */
+uint16_t g_ECSSourcePrefixV4 = 24;
+uint16_t g_ECSSourcePrefixV6 = 56;
+
+bool g_ECSOverride{false};
+
+int rewriteResponseWithoutEDNS(const char * packet, const size_t len, vector<uint8_t>& newContent)
+{
+  assert(packet != NULL);
+  assert(len >= sizeof(dnsheader));
+  const struct dnsheader* dh = (const struct dnsheader*) packet;
+
+  if (ntohs(dh->arcount) == 0)
+    return ENOENT;
+
+  if (ntohs(dh->qdcount) == 0)
+    return ENOENT;
+    
+  vector<uint8_t> content(len - sizeof(dnsheader));
+  copy(packet + sizeof(dnsheader), packet + len, content.begin());
+  PacketReader pr(content);
+  
+  size_t idx = 0;
+  DNSName rrname;
+  uint16_t qdcount = ntohs(dh->qdcount);
+  uint16_t ancount = ntohs(dh->ancount);
+  uint16_t nscount = ntohs(dh->nscount);
+  uint16_t arcount = ntohs(dh->arcount);
+  uint16_t rrtype;
+  uint16_t rrclass;
+  string blob;
+  struct dnsrecordheader ah;
+
+  rrname = pr.getName();
+  rrtype = pr.get16BitInt();
+  rrclass = pr.get16BitInt();
+  
+  DNSPacketWriter pw(newContent, rrname, rrtype, rrclass, dh->opcode);
+  pw.getHeader()->id=dh->id;
+  pw.getHeader()->qr=dh->qr;
+  pw.getHeader()->aa=dh->aa;
+  pw.getHeader()->tc=dh->tc;
+  pw.getHeader()->rd=dh->rd;
+  pw.getHeader()->ra=dh->ra;
+  pw.getHeader()->ad=dh->ad;
+  pw.getHeader()->cd=dh->cd;
+  pw.getHeader()->rcode=dh->rcode;
+  
+  /* consume remaining qd if any */
+  if (qdcount > 1) {
+    for(idx = 1; idx < qdcount; idx++) {
+      rrname = pr.getName();
+      rrtype = pr.get16BitInt();
+      rrclass = pr.get16BitInt();
+      (void) rrtype;
+      (void) rrclass;
+    }
+  }
+
+  /* copy AN and NS */
+  for (idx = 0; idx < ancount; idx++) {
+    rrname = pr.getName();
+    pr.getDnsrecordheader(ah);
+
+    pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true);
+    pr.xfrBlob(blob);
+    pw.xfrBlob(blob);
+  }
+
+  for (idx = 0; idx < nscount; idx++) {
+    rrname = pr.getName();
+    pr.getDnsrecordheader(ah);
+
+    pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true);
+    pr.xfrBlob(blob);
+    pw.xfrBlob(blob);
+  }
+
+  /* consume AR, looking for OPT */
+  for (idx = 0; idx < arcount; idx++) {
+    rrname = pr.getName();
+    pr.getDnsrecordheader(ah);
+
+    if (ah.d_type != QType::OPT) {
+      pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true);
+      pr.xfrBlob(blob);
+      pw.xfrBlob(blob);
+    } else {
+      pr.d_pos += ah.d_clen;
+    }
+  }
+  pw.commit();
+
+  return 0;
+}
+
+int locateEDNSOptRR(const char * packet, const size_t len, const char ** optStart, size_t * optLen, bool * last)
+{
+  assert(packet != NULL);
+  assert(optStart != NULL);
+  assert(optLen != NULL);
+  assert(last != NULL);
+  const struct dnsheader* dh = (const struct dnsheader*) packet;
+
+  if (ntohs(dh->arcount) == 0)
+    return ENOENT;
+
+  vector<uint8_t> content(len - sizeof(dnsheader));
+  copy(packet + sizeof(dnsheader), packet + len, content.begin());
+  PacketReader pr(content);
+  size_t idx = 0;
+  DNSName rrname;
+  uint16_t qdcount = ntohs(dh->qdcount);
+  uint16_t ancount = ntohs(dh->ancount);
+  uint16_t nscount = ntohs(dh->nscount);
+  uint16_t arcount = ntohs(dh->arcount);
+  uint16_t rrtype;
+  uint16_t rrclass;
+  struct dnsrecordheader ah;
+
+  /* consume qd */
+  for(idx = 0; idx < qdcount; idx++) {
+    rrname = pr.getName();
+    rrtype = pr.get16BitInt();
+    rrclass = pr.get16BitInt();
+    (void) rrtype;
+    (void) rrclass;
+  }
+
+  /* consume AN and NS */
+  for (idx = 0; idx < ancount + nscount; idx++) {
+    rrname = pr.getName();
+    pr.getDnsrecordheader(ah);
+    pr.d_pos += ah.d_clen;
+  }
+
+  /* consume AR, looking for OPT */
+  for (idx = 0; idx < arcount; idx++) {
+    uint16_t start = pr.d_pos;
+    rrname = pr.getName();
+    pr.getDnsrecordheader(ah);
+
+    if (ah.d_type == QType::OPT) {
+      *optStart = packet + sizeof(dnsheader) + start;
+      *optLen = (pr.d_pos - start) + ah.d_clen;
+
+      if ((packet + len) < (*optStart + *optLen)) {
+        throw std::range_error("Opt record overflow");
+      }
+
+      if (idx == ((size_t) arcount - 1)) {
+        *last = true;
+      }
+      else {
+        *last = false;
+      }
+      return 0;
+    }
+    pr.d_pos += ah.d_clen;
+  }
+
+  return ENOENT;
+}
+
+/* extract the start of the OPT RR in a QUERY packet if any */
+static int getEDNSOptionsStart(char* packet, const size_t offset, const size_t len, char ** optStart, size_t * remaining, uint16_t ** optRDLen)
+{
+  assert(packet != NULL);
+  assert(optStart != NULL);
+  assert(remaining != NULL);
+  const struct dnsheader* dh = (const struct dnsheader*) packet;
+  
+  if (offset >= len)
+    return ENOENT;
+
+  if (ntohs(dh->qdcount) != 1 || dh->ancount != 0 || ntohs(dh->arcount) != 1 || dh->nscount != 0)
+    return ENOENT;
+
+  size_t pos = sizeof(dnsheader) + offset;
+  pos += DNS_TYPE_SIZE + DNS_CLASS_SIZE;
+
+  if (pos >= len)
+    return ENOENT;
+
+  uint16_t qtype, qclass;
+  unsigned int consumed;
+  DNSName aname(packet, len, pos, true, &qtype, &qclass, &consumed);
+
+  pos += consumed + DNS_TYPE_SIZE + DNS_CLASS_SIZE;
+  if(qtype != QType::OPT || (len - pos) < (DNS_TTL_SIZE + DNS_RDLENGTH_SIZE))
+    return ENOENT;
+
+  *optStart = packet + pos;
+  *remaining = len - pos;
+
+  if (optRDLen) {
+    *optRDLen = (uint16_t *)(packet + pos + DNS_TTL_SIZE);
+  }
+
+  return 0;
+}
+
+/* extract a specific EDNS0 option from a pointer on the beginning of the OPT RR */
+static int getEDNSOption(char* optRR, const size_t len, const uint16_t wantedOption, char ** optionValue, size_t * optionValueSize)
+{
+  assert(optRR != NULL);
+  assert(optionValue != NULL);
+  assert(optionValueSize != NULL);
+  size_t pos = 0;
+
+  pos += DNS_TTL_SIZE;
+  const uint16_t rdLen = ntohs(*((uint16_t*) (optRR + pos)));
+  size_t rdPos = 0;
+  pos += DNS_RDLENGTH_SIZE;  
+
+  while(pos < (len - ((size_t) EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)) &&
+        rdPos < (rdLen - ((size_t) EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE))) {
+    const uint16_t optionCode = ntohs(*((uint16_t*) (optRR + pos)));
+    pos += EDNS_OPTION_CODE_SIZE;
+    rdPos += EDNS_OPTION_CODE_SIZE;
+    const uint16_t optionLen = ntohs(*((uint16_t*) (optRR + pos)));
+    pos += EDNS_OPTION_LENGTH_SIZE;
+    rdPos += EDNS_OPTION_LENGTH_SIZE;
+
+    if (optionLen > (rdLen - rdPos) || optionLen > (len - pos))
+      return EINVAL;
+    
+    if (optionCode == wantedOption) {
+      *optionValue = optRR + pos - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE);
+      *optionValueSize = optionLen + EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE;
+      return 0;
+    }
+    else {
+      /* skip this option */
+      pos += optionLen;
+      rdPos += optionLen;
+    }
+  }
+  
+  return ENOENT;
+}
+
+static void generateECSOption(const ComboAddress& source, string& res)
+{
+  const uint16_t ecsOptionCode = htons(EDNS0_OPTION_CODE_ECS);
+  Netmask sourceNetmask(source, source.sin4.sin_family == AF_INET ? g_ECSSourcePrefixV4 : g_ECSSourcePrefixV6);
+  EDNSSubnetOpts ecsOpts;
+  ecsOpts.source = sourceNetmask;
+  string payload = makeEDNSSubnetOptsString(ecsOpts);
+  const uint16_t payloadLen = htons(payload.length());
+  res.append((const char *) &ecsOptionCode, sizeof ecsOptionCode);
+  res.append((const char *) &payloadLen, sizeof payloadLen);
+  res.append(payload);
+}
+
+static void generateECSOptRR(const ComboAddress& source, string & res)
+{
+  const uint8_t name = 0;
+  dnsrecordheader dh;
+  EDNS0Record edns0;
+  edns0.extRCode = 0;
+  edns0.version = 0;
+  edns0.Z = 0;
+  
+  dh.d_type = htons(QType::OPT);
+  dh.d_class = htons(q_EdnsUDPPayloadSize);
+  memcpy(&dh.d_ttl, &edns0, sizeof edns0);
+  string optRData;
+  generateECSOption(source, optRData);
+  dh.d_clen = htons((uint16_t) optRData.length());
+  res.assign((const char *) &name, sizeof name);
+  res.append((const char *) &dh, sizeof dh);
+  res.append(optRData.c_str(), optRData.length());
+}
+
+static void replaceEDNSClientSubnetOption(char * const packet, const size_t packetSize, int * const len, string& largerPacket, const ComboAddress& remote, char * const oldEcsOptionStart, size_t const oldEcsOptionSize, uint16_t * const optRDLen)
+{
+  assert(packet != NULL);
+  assert(len != NULL);
+  assert(oldEcsOptionStart != NULL);
+  assert(optRDLen != NULL);
+  string ECSOption;
+  generateECSOption(remote, ECSOption);
+
+  if (ECSOption.size() == oldEcsOptionSize) {
+    /* same size as the existing option */
+    memcpy(oldEcsOptionStart, ECSOption.c_str(), oldEcsOptionSize);
+  }
+  else {
+    /* different size than the existing option */
+    const unsigned int newPacketLen = *len + (ECSOption.length() - oldEcsOptionSize);
+    const size_t beforeOptionLen = oldEcsOptionStart - packet;
+    const size_t dataBehindSize = *len - beforeOptionLen - oldEcsOptionSize;
+          
+    /* fix the size of ECS Option RDLen */
+    uint16_t newRDLen = htons(*optRDLen);
+    newRDLen += (ECSOption.size() - oldEcsOptionSize);
+    *optRDLen = htons(newRDLen);
+    
+    if (newPacketLen <= packetSize) {
+      /* it fits in the existing buffer */
+      if (dataBehindSize > 0) {
+        memmove(oldEcsOptionStart, oldEcsOptionStart + oldEcsOptionSize, dataBehindSize);
+      }
+      memcpy(oldEcsOptionStart + dataBehindSize, ECSOption.c_str(), ECSOption.size());
+      *len = newPacketLen;
+    }
+    else {
+      /* We need a larger packet */
+      if (newPacketLen > largerPacket.capacity()) {
+        largerPacket.reserve(newPacketLen);
+      }
+      /* copy data before the existing option */
+      largerPacket.append(packet, beforeOptionLen);
+      /* copy the new option */
+      largerPacket.append(ECSOption);
+      /* copy data that where behind the existing option */
+      if (dataBehindSize > 0) {
+        largerPacket.append(oldEcsOptionStart + oldEcsOptionSize, dataBehindSize);
+      }
+    }
+  }
+}
+
+void handleEDNSClientSubnet(char * const packet, const size_t packetSize, const unsigned int consumed, int * const len, string& largerPacket, bool * const ednsAdded, const ComboAddress& remote)
+{
+  assert(packet != NULL);
+  assert(len != NULL);
+  assert(*len >= 0);
+  assert(consumed <= (size_t) *len);
+  assert(ednsAdded != NULL);
+  char * optRRStart = NULL;
+  uint16_t * optRDLen = NULL;
+  size_t remaining = 0;
+        
+  int res = getEDNSOptionsStart(packet, consumed, *len, &optRRStart, &remaining, &optRDLen);
+        
+  if (res == 0) {
+    char * ecsOptionStart = NULL;
+    size_t ecsOptionSize = 0;
+    
+    res = getEDNSOption(optRRStart, remaining, EDNS0_OPTION_CODE_ECS, &ecsOptionStart, &ecsOptionSize);
+    
+    if (res == 0) {
+      /* there is already an EDNS0_OPTION_CODE_ECS value */
+      if (g_ECSOverride) {
+        replaceEDNSClientSubnetOption(packet, packetSize, len, largerPacket, remote, ecsOptionStart, ecsOptionSize, optRDLen);
+      }
+    } else {
+      /* we need to add one EDNS0 ECS option, fixing the size of EDNS0 RDLENGTH */
+      /* getEDNSOptionsStart has already checked that there is exactly one AR,
+         no NS and no AN */
+      string ECSOption;
+      generateECSOption(remote, ECSOption);
+      const size_t ECSOptionSize = ECSOption.size();
+      
+      uint16_t newRDLen = htons(*optRDLen);
+      newRDLen += ECSOptionSize;
+      *optRDLen = htons(newRDLen);
+
+      if (packetSize - *len > ECSOptionSize) {
+        /* if the existing buffer is large enough */
+        memcpy(packet + *len, ECSOption.c_str(), ECSOptionSize);
+        *len += ECSOptionSize;
+      }
+      else {
+        if (*len + ECSOptionSize > largerPacket.capacity()) {
+          largerPacket.reserve(*len + ECSOptionSize);
+        }
+        
+        largerPacket.append(packet, *len);
+        largerPacket.append(ECSOption);
+      }
+    }
+  }
+  else {
+    /* we need to add a EDNS0 RR with one EDNS0 ECS option, fixing the AR count */
+    string EDNSRR;
+    struct dnsheader* dh = (struct dnsheader*) packet;
+    generateECSOptRR(remote, EDNSRR);
+    uint16_t arcount = ntohs(dh->arcount);
+    arcount++;
+    dh->arcount = htons(arcount);
+    *ednsAdded = true;
+
+    /* does it fit in the existing buffer? */
+    if (packetSize - *len > EDNSRR.size()) {
+      memcpy(packet + *len, EDNSRR.c_str(), EDNSRR.size());
+      *len += EDNSRR.size();
+    }
+    else {
+      if (*len + EDNSRR.size() > largerPacket.capacity()) {
+        largerPacket.reserve(*len + EDNSRR.size());
+      }
+      
+      largerPacket.append(packet, *len);
+      largerPacket.append(EDNSRR);
+    }
+  }
+}
diff --git a/pdns/dnsdist-ecs.hh b/pdns/dnsdist-ecs.hh
new file mode 100644 (file)
index 0000000..4edbac7
--- /dev/null
@@ -0,0 +1,8 @@
+#pragma once
+
+int rewriteResponseWithoutEDNS(const char * packet, size_t len, vector<uint8_t>& newContent);
+int locateEDNSOptRR(const char * packet, size_t len, const char ** optStart, size_t * optLen, bool * last);
+void handleEDNSClientSubnet(char * packet, size_t packetSize, unsigned int consumed, int * len, string& largerPacket, bool * ednsAdded, const ComboAddress& remote);
+
+
+
index 56fcf45a59dcf9f895a34c923b4e9e7404ea8327..4b482c2c7131c8f5ece9de72803b37c319b280fe 100644 (file)
@@ -181,6 +181,10 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
                          ret->mustResolve=boost::get<bool>(vars["mustResolve"]);
                        }
 
+                       if(vars.count("useClientSubnet")) {
+                         ret->useECS=boost::get<bool>(vars["useClientSubnet"]);
+                       }
+
                        if(g_launchWork) {
                          g_launchWork->push_back([ret]() {
                              ret->tid = move(thread(responderThread, ret));
@@ -910,6 +914,12 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
 
   g_lua.writeFunction("setMaxTCPClientThreads", [](uint64_t max) { g_maxTCPClientThreads = max; });
 
+  g_lua.writeFunction("setECSSourcePrefixV4", [](uint16_t prefix) { g_ECSSourcePrefixV4=prefix; });
+
+  g_lua.writeFunction("setECSSourcePrefixV6", [](uint16_t prefix) { g_ECSSourcePrefixV6=prefix; });
+
+  g_lua.writeFunction("setECSOverride", [](bool override) { g_ECSOverride=override; });
+
   g_lua.writeFunction("dumpStats", [] {
       vector<string> leftcolumn, rightcolumn;
 
index 9d6a0b81ee19a0a3603b3f7b3d4c740a439066a5..60936ad78891476b8419596835789314b8943925 100644 (file)
@@ -21,6 +21,7 @@
 */
 
 #include "dnsdist.hh"
+#include "dnsdist-ecs.hh"
 #include "dolog.hh"
 #include "lock.hh"
 #include <thread>
@@ -140,7 +141,9 @@ void* tcpClientThread(int pipefd)
     const uint16_t rdMask = 1 << FLAGS_RD_OFFSET;
     const uint16_t cdMask = 1 << FLAGS_CD_OFFSET;
     const uint16_t restoreFlagsMask = UINT16_MAX & ~(rdMask | cdMask);
-
+    string largerQuery;
+    vector<uint8_t> rewrittenResponse;
+    bool ednsAdded = false;
     shared_ptr<DownstreamState> ds;
     if (!setNonBlocking(ci.fd))
       goto drop;
@@ -155,10 +158,13 @@ void* tcpClientThread(int pipefd)
           break;
         }
 
-        char query[qlen];
-        readn2WithTimeout(ci.fd, query, qlen, g_tcpRecvTimeout);
+        char queryBuffer[qlen];
+        const char * query = queryBuffer;
+        size_t queryLen = qlen;
+        readn2WithTimeout(ci.fd, queryBuffer, queryLen, g_tcpRecvTimeout);
        uint16_t qtype;
-       DNSName qname(query, qlen, 12, false, &qtype);
+       unsigned int consumed = 0;
+       DNSName qname(query, queryLen, sizeof(dnsheader), false, &qtype, 0, &consumed);
        string ruleresult;
        struct dnsheader* dh =(dnsheader*)query;
        const uint16_t * flags = getFlagsFromDNSHeader(dh);
@@ -204,8 +210,8 @@ void* tcpClientThread(int pipefd)
        
        DNSAction::Action action=DNSAction::Action::None;
        for(const auto& lr : *localRulactions) {
-         if(lr.first->matches(ci.remote, qname, qtype, dh, qlen)) {
-           action=(*lr.second)(ci.remote, qname, qtype, dh, qlen, &ruleresult);
+         if(lr.first->matches(ci.remote, qname, qtype, dh, queryLen)) {
+           action=(*lr.second)(ci.remote, qname, qtype, dh, queryLen, &ruleresult);
            if(action != DNSAction::Action::None) {
              lr.first->d_matches++;
              break;
@@ -238,8 +244,8 @@ void* tcpClientThread(int pipefd)
        }
        
        if(dh->qr) { // something turned it into a response
-         if (putNonBlockingMsgLen(ci.fd, qlen, g_tcpSendTimeout))
-           writen2WithTimeout(ci.fd, query, rlen, g_tcpSendTimeout);
+         if (putNonBlockingMsgLen(ci.fd, queryLen, g_tcpSendTimeout))
+           writen2WithTimeout(ci.fd, query, queryLen, g_tcpSendTimeout);
 
          g_stats.selfAnswered++;
          goto drop;
@@ -254,6 +260,18 @@ void* tcpClientThread(int pipefd)
          g_stats.noPolicy++;
          break;
        }
+
+        if (ds->useECS) {
+          int newLen = queryLen;
+          handleEDNSClientSubnet(queryBuffer, queryLen, consumed, &newLen, largerQuery, &ednsAdded, ci.remote);
+          if (largerQuery.empty() == false) {
+            query = largerQuery.c_str();
+            queryLen = largerQuery.size();
+          } else {
+            queryLen = newLen;
+          }
+        }
+
        if(sockets.count(ds->remote) == 0) {
          dsock=sockets[ds->remote]=setupTCPDownstream(ds->remote);
        }
@@ -280,7 +298,7 @@ void* tcpClientThread(int pipefd)
           break;
         }
 
-        if(!putNonBlockingMsgLen(dsock, qlen, ds->tcpSendTimeout)) {
+        if(!putNonBlockingMsgLen(dsock, queryLen, ds->tcpSendTimeout)) {
          vinfolog("Downstream connection to %s died on us, getting a new one!", ds->getName());
           close(dsock);
           sockets[ds->remote]=dsock=setupTCPDownstream(ds->remote);
@@ -289,7 +307,7 @@ void* tcpClientThread(int pipefd)
         }
 
         try {
-          writen2WithTimeout(dsock, query, qlen, ds->tcpSendTimeout);
+          writen2WithTimeout(dsock, query, queryLen, ds->tcpSendTimeout);
         }
         catch(const runtime_error& e) {
           vinfolog("Downstream connection to %s died on us, getting a new one!", ds->getName());
@@ -317,11 +335,44 @@ void* tcpClientThread(int pipefd)
         origFlags &= ~restoreFlagsMask;
         /* set the saved flags as they were */
         *responseFlags |= origFlags;
+        char * response = answerbuffer;
+        size_t responseLen = rlen;
+
+        if (ednsAdded) {
+          const char * optStart = NULL;
+          size_t optLen = 0;
+          bool last = false;
+
+          int res = locateEDNSOptRR(response, responseLen, &optStart, &optLen, &last);
+
+          if (res == 0) {
+            if (last) {
+              /* simply remove the last AR */
+              responseLen -= optLen;
+              uint16_t arcount = ntohs(responseHeaders->arcount);
+              arcount--;
+              responseHeaders->arcount = htons(arcount);
+            }
+            else {
+              /* Removing an intermediary RR could lead to compression error */
+              if (rewriteResponseWithoutEDNS(response, responseLen, rewrittenResponse) == 0) {
+                response = reinterpret_cast<char*>(rewrittenResponse.data());
+                responseLen = rewrittenResponse.size();
+              }
+              else {
+                warnlog("Error rewriting content");
+              }
+            }
+          }
+        }
 
-        if (putNonBlockingMsgLen(ci.fd, rlen, ds->tcpSendTimeout))
-          writen2WithTimeout(ci.fd, answerbuffer, rlen, ds->tcpSendTimeout);
+        if (putNonBlockingMsgLen(ci.fd, responseLen, ds->tcpSendTimeout))
+          writen2WithTimeout(ci.fd, response, responseLen, ds->tcpSendTimeout);
 
         g_stats.responses++;
+        
+        largerQuery.clear();
+        rewrittenResponse.clear();
       }
     }
     catch(...){}
@@ -392,8 +443,8 @@ bool getMsgLen(int fd, uint16_t* len)
 try
 {
   uint16_t raw;
-  int ret = readn2(fd, &raw, 2);
-  if(ret != 2)
+  int ret = readn2(fd, &raw, sizeof raw);
+  if(ret != sizeof raw)
     return false;
   *len = ntohs(raw);
   return true;
@@ -406,8 +457,8 @@ bool putMsgLen(int fd, uint16_t len)
 try
 {
   uint16_t raw = htons(len);
-  int ret = writen2(fd, &raw, 2);
-  return ret==2;
+  int ret = writen2(fd, &raw, sizeof raw);
+  return ret==sizeof raw;
 }
 catch(...) {
   return false;
index 6c46833a05c541a21b57ccdbcb426b7f6116b53d..a713144baabb60375ba6eab299e9c0a65f3f12fb 100644 (file)
@@ -21,6 +21,7 @@
 */
 
 #include "dnsdist.hh"
+#include "dnsdist-ecs.hh"
 #include "sstuff.hh"
 #include "misc.hh"
 #include <netinet/tcp.h>
@@ -108,12 +109,12 @@ int g_tcpRecvTimeout{2};
 int g_tcpSendTimeout{2};
 
 bool g_truncateTC{1};
-void truncateTC(const char* packet, unsigned int* len)
+static void truncateTC(const char* packet, unsigned int* len)
 try
 {
   unsigned int consumed;
-  DNSName qname(packet, *len, 12, false, 0, 0, &consumed);
-  *len=consumed+12+4;
+  DNSName qname(packet, *len, sizeof(dnsheader), false, 0, 0, &consumed);
+  *len=sizeof(dnsheader)+consumed+DNS_TYPE_SIZE+DNS_CLASS_SIZE;
   struct dnsheader* dh =(struct dnsheader*)packet;
   dh->ancount = dh->arcount = dh->nscount=0;
 }
@@ -139,6 +140,7 @@ struct DelayedPacket
 
 DelayPipe<DelayedPacket> * g_delay = 0;
 
+
 // listens on a dedicated socket, lobs answers from downstream servers to original requestors
 void* responderThread(std::shared_ptr<DownstreamState> state)
 {
@@ -146,11 +148,15 @@ void* responderThread(std::shared_ptr<DownstreamState> state)
   const uint16_t rdMask = 1 << FLAGS_RD_OFFSET;
   const uint16_t cdMask = 1 << FLAGS_CD_OFFSET;
   const uint16_t restoreFlagsMask = UINT16_MAX & ~(rdMask | cdMask);
+  vector<uint8_t> rewrittenResponse;
   
   struct dnsheader* dh = (struct dnsheader*)packet;
   int len;
   for(;;) {
     len = recv(state->fd, packet, sizeof(packet), 0);
+    const char * response = packet;
+    size_t responseLen = len;
+
     if(len < (signed)sizeof(dnsheader))
       continue;
 
@@ -179,17 +185,46 @@ void* responderThread(std::shared_ptr<DownstreamState> state)
     *flags |= origFlags;
 
     dh->id = ids->origID;
+
+    if (ids->ednsAdded) {
+      const char * optStart = NULL;
+      size_t optLen = 0;
+      bool last = false;
+
+      int res = locateEDNSOptRR(packet, len, &optStart, &optLen, &last);
+
+      if (res == 0) {
+        if (last) {
+          /* simply remove the last AR */
+          responseLen -= optLen;
+          uint16_t arcount = ntohs(dh->arcount);
+          arcount--;
+          dh->arcount = htons(arcount);
+        }
+        else {
+          /* Removing an intermediary RR could lead to compression error */
+          if (rewriteResponseWithoutEDNS(packet, len, rewrittenResponse) == 0) {
+            response = reinterpret_cast<char*>(rewrittenResponse.data());
+            responseLen = rewrittenResponse.size();
+          }
+          else {
+            warnlog("Error rewriting content");
+          }
+        }
+      }
+    }
+
     g_stats.responses++;
 
     if(ids->delayMsec && g_delay) {
-      DelayedPacket dp{origFD, string(packet,len), ids->origRemote, ids->origDest};
+      DelayedPacket dp{origFD, string(response,responseLen), ids->origRemote, ids->origDest};
       g_delay->submit(dp, ids->delayMsec);
     }
     else {
       if(ids->origDest.sin4.sin_family == 0)
-       sendto(origFD, packet, len, 0, (struct sockaddr*)&ids->origRemote, ids->origRemote.getSocklen());
+       sendto(origFD, response, responseLen, 0, (struct sockaddr*)&ids->origRemote, ids->origRemote.getSocklen());
       else
-       sendfromto(origFD, packet, len, 0, ids->origDest, ids->origRemote);
+       sendfromto(origFD, response, responseLen, 0, ids->origDest, ids->origRemote);
     }
     double udiff = ids->sentTime.udiff();
     vinfolog("Got answer from %s, relayed to %s, took %f usec", state->remote.toStringWithPort(), ids->origRemote.toStringWithPort(), udiff);
@@ -200,7 +235,7 @@ void* responderThread(std::shared_ptr<DownstreamState> state)
       std::lock_guard<std::mutex> lock(g_rings.respMutex);
       g_rings.respRing.push_back({ts, ids->origRemote, ids->qname, ids->qtype, (uint8_t)dh->rcode, (unsigned int)udiff, (unsigned int)len});
     }
-    if(dh->rcode == 2)
+    if(dh->rcode == RCode::ServFail)
       g_stats.servfailResponses++;
     state->latencyUsec = (127.0 * state->latencyUsec / 128.0) + udiff/128.0;
 
@@ -222,6 +257,8 @@ void* responderThread(std::shared_ptr<DownstreamState> state)
 
     if (ids->origFD == origFD)
       ids->origFD = -1;
+
+    rewrittenResponse.clear();
   }
   return 0;
 }
@@ -379,26 +416,31 @@ int getEDNSZ(const char* packet, unsigned int len)
 {
   struct dnsheader* dh =(struct dnsheader*)packet;
 
-  if(dh->ancount!=0 && ntohs(dh->arcount)!=1 && dh->nscount!=0)
+  if(ntohs(dh->qdcount) != 1 || dh->ancount!=0 || ntohs(dh->arcount)!=1 || dh->nscount!=0)
     return 0;
-  
+
+  if (len <= sizeof(dnsheader))
+    return 0;
+
   unsigned int consumed;
-  DNSName qname(packet, len, 12, false, 0, 0, &consumed);
-  int pos = consumed + 4;
+  DNSName qname(packet, len, sizeof(dnsheader), false, 0, 0, &consumed);
+  size_t pos = consumed + DNS_TYPE_SIZE + DNS_CLASS_SIZE;
   uint16_t qtype, qclass;
 
-  DNSName aname(packet, len, 12+pos, true, &qtype, &qclass, &consumed);
-  
-  if(qtype!=QType::OPT || 12+pos+consumed+7 >= len)
+  if (len <= (sizeof(dnsheader)+pos))
     return 0;
 
-  uint8_t* z = (uint8_t*)packet+12+pos+consumed+6;
+  DNSName aname(packet, len, sizeof(dnsheader)+pos, true, &qtype, &qclass, &consumed);
+
+  if(qtype!=QType::OPT || sizeof(dnsheader)+pos+consumed+DNS_TYPE_SIZE+DNS_CLASS_SIZE+EDNS_EXTENDED_RCODE_SIZE+EDNS_VERSION_SIZE+1 >= len)
+    return 0;
+
+  uint8_t* z = (uint8_t*)packet+sizeof(dnsheader)+pos+consumed+DNS_TYPE_SIZE+DNS_CLASS_SIZE+EDNS_EXTENDED_RCODE_SIZE+EDNS_VERSION_SIZE;
   return 0x100 * (*z) + *(z+1);
 }
 
-
 // listens to incoming queries, sends out to downstream servers, noting the intended return path 
-void* udpClientThread(ClientState* cs)
+static void* udpClientThread(ClientState* cs)
 try
 {
   ComboAddress remote;
@@ -406,8 +448,7 @@ try
   char packet[1500];
   struct dnsheader* dh = (struct dnsheader*) packet;
   int len;
-
-  string qname;
+  string largerQuery;
   uint16_t qtype;
 
   typedef std::function<bool(ComboAddress, DNSName, uint16_t, dnsheader*)> blockfilter_t;
@@ -427,6 +468,7 @@ try
   auto localDynBlock = g_dynblockNMG.getLocal();
   struct msghdr msgh;
   struct iovec iov;
+  /* used by HarvestDestinationAddress */
   char cbuf[256];
 
   remote.sin6.sin6_family=cs->local.sin6.sin6_family;
@@ -468,7 +510,8 @@ try
       
       const uint16_t * flags = getFlagsFromDNSHeader(dh);
       const uint16_t origFlags = *flags;
-      DNSName qname(packet, len, 12, false, &qtype);
+      unsigned int consumed = 0;
+      DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
       struct timespec now;
       clock_gettime(CLOCK_MONOTONIC, &now);
       {
@@ -581,16 +624,28 @@ try
       ids->origDest.sin4.sin_family=0;
       ids->delayMsec = delayMsec;
       ids->origFlags = origFlags;
+      ids->ednsAdded = false;
       HarvestDestinationAddress(&msgh, &ids->origDest);
 
       dh->id = idOffset;
+
+      if (ss->useECS) {
+        handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerQuery, &(ids->ednsAdded), remote);
+      }
+      
+      if (largerQuery.empty()) {
+        len = send(ss->fd, packet, len, 0);
+      }
+      else {
+        len = send(ss->fd, largerQuery.c_str(), largerQuery.size(), 0);
+        largerQuery.clear();
+      }
       
-      len = send(ss->fd, packet, len, 0);
       if(len < 0) {
        ss->sendErrors++;
        g_stats.downstreamSendErrors++;
       }
-      
+
       vinfolog("Got query from %s, relayed to %s", remote.toStringWithPort(), ss->getName());
     }
     catch(std::exception& e){
@@ -711,7 +766,7 @@ void* maintThread()
 
 string g_key;
 
-void controlClientThread(int fd, ComboAddress client)
+static void controlClientThread(int fd, ComboAddress client)
 try
 {
   SodiumNonce theirs;
@@ -807,7 +862,7 @@ catch(std::exception& e)
 
 
 
-void doClient(ComboAddress server, const std::string& command)
+static void doClient(ComboAddress server, const std::string& command)
 {
   cout<<"Connecting to "<<server.toStringWithPort()<<endl;
   int fd=socket(server.sin4.sin_family, SOCK_STREAM, 0);
@@ -882,7 +937,7 @@ void doClient(ComboAddress server, const std::string& command)
   }
 }
 
-void doConsole()
+static void doConsole()
 {
   set<string> dupper;
   {
@@ -1060,8 +1115,8 @@ char* my_generator(const char* text, int state)
   vector<string> words{"showRules()", "shutdown()", "rmRule(", "mvRule(", "addACL(", "addLocal(", "setServerPolicy(", "setServerPolicyLua(",
       "newServer(", "rmServer(", "showServers()", "show(", "newDNSName(", "newSuffixMatchNode(", "controlSocket(", "topClients(", "showResponseLatency()", 
       "newQPSLimiter(", "makeKey()", "setKey(", "testCrypto()", "addAnyTCRule()", "showServerPolicy()", "setACL(", "showACL()", "addDomainBlock(", 
-      "addPoolRule(", "addQPSLimit(", "topResponses(", "topQueries(", "topRule()", "setDNSSECPool(", "addDelay(",
-      "setMaxUDPOutstanding(", "setMaxTCPClientThreads("};
+      "addPoolRule(", "addQPSLimit(", "topResponses(", "topQueries(", "topRule()", "setDNSSECPool(", "setECSOverride(", "setECSSourcePrefixV4(",
+      "setECSSourcePrefixV6(", "addDelay(", "setTCPRecvTimeout(", "setTCPSendTimeout(", "setMaxTCPClientThreads(", "setMaxUDPOutstanding(" };
   static int s_counter=0;
   int counter=0;
   if(!state)
index f3b71fd319752174f6dce6e65c5f1705a50d687a..c48fda1c5e3c116aa1b5d958a255d5c2a03c40b4 100644 (file)
@@ -195,6 +195,7 @@ struct IDState
   uint16_t origID;                                            // 2
   uint16_t origFlags;                                         // 2
   int delayMsec;
+  bool ednsAdded{false};
 };
 
 struct Rings {
@@ -298,6 +299,7 @@ struct DownstreamState
   enum class Availability { Up, Down, Auto} availability{Availability::Auto};
   bool mustResolve;
   bool upStatus{false};
+  bool useECS{false};
   bool isUp() const
   {
     if(availability == Availability::Down)
@@ -373,6 +375,16 @@ struct CarbonConfig
   unsigned int interval{30};
 };
 
+enum ednsHeaderFlags {
+  EDNS_HEADER_FLAG_NONE = 0,
+  EDNS_HEADER_FLAG_DO = 32768
+};
+
+enum ednsOptionCodes {
+  EDNS0_OPTION_CODE_NONE = 0,
+  EDNS0_OPTION_CODE_ECS = 8,
+};
+
 extern GlobalStateHolder<CarbonConfig> g_carbon;
 extern GlobalStateHolder<ServerPolicy> g_policy;
 extern GlobalStateHolder<servers_t> g_dstates;
@@ -390,6 +402,10 @@ extern int g_tcpSendTimeout;
 extern uint16_t g_maxOutstanding;
 extern std::atomic<bool> g_configurationDone;
 extern std::atomic<uint64_t> g_maxTCPClientThreads;
+extern uint16_t g_ECSSourcePrefixV4;
+extern uint16_t g_ECSSourcePrefixV6;
+extern bool g_ECSOverride;
+
 struct dnsheader;
 
 void controlThread(int fd, ComboAddress local);
@@ -403,6 +419,7 @@ std::shared_ptr<DownstreamState> wrandom(const NumberedServerVector& servers, co
 std::shared_ptr<DownstreamState> whashed(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh);
 std::shared_ptr<DownstreamState> roundrobin(const NumberedServerVector& servers, const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh);
 int getEDNSZ(const char* packet, unsigned int len);
+uint16_t getEDNSOptionCode(const char * packet, size_t len);
 void dnsdistWebserverThread(int sock, const ComboAddress& local, const string& password);
 bool getMsgLen(int fd, uint16_t* len);
 bool putMsgLen(int fd, uint16_t len);
index 1fbb583cb92b05b85040bec6255b0bd7fbc8c0d1..3915cce44774ad8d4478a699a21ae5cbf93432d7 100644 (file)
@@ -23,12 +23,18 @@ EXTRA_DIST=dnslabeltext.rl \
           .version \
           contrib
 
+if UNIT_TESTS
+bin_PROGRAMS = dnsdist testrunner
+else
 bin_PROGRAMS = dnsdist
+endif
+
 dnsdist_SOURCES = \
        base64.hh \
        dns.hh \
        dnsdist.cc dnsdist.hh \
        dnsdist-carbon.cc \
+       dnsdist-ecs.cc dnsdist-ecs.hh \
        dnsdist-lua.cc \
        dnsdist-lua2.cc \
        dnsdist-rings.cc \
@@ -36,10 +42,11 @@ dnsdist_SOURCES = \
        dnsdist-web.cc \
        dnslabeltext.cc \
        dnsname.cc dnsname.hh \
-       dnsparser.hh \
+       dnsparser.hh dnsparser.cc \
        dnsrulactions.hh \
        dnswriter.cc dnswriter.hh \
        dolog.hh \
+       ednssubnet.cc ednssubnet.hh \
        iputils.cc iputils.hh \
        lock.hh \
        misc.cc misc.hh \
@@ -73,3 +80,33 @@ dnsdist_LDADD = \
        $(LIBSODIUM_LIBS) \
        $(SANITIZER_FLAGS)
 
+
+testrunner_SOURCES = \
+       dns.hh \
+       test-dnsdist_cc.cc dnsdist.hh \
+       dnsdist-ecs.cc dnsdist-ecs.hh \
+       dnslabeltext.cc \
+       dnsname.cc dnsname.hh \
+       dnsparser.hh dnsparser.cc \
+       dnswriter.cc dnswriter.hh \
+       dolog.hh \
+       ednssubnet.cc ednssubnet.hh \
+       iputils.cc iputils.hh \
+       misc.cc misc.hh \
+       namespaces.hh \
+       pdnsexception.hh \
+       qtype.cc qtype.hh \
+       sholder.hh \
+       sstuff.hh \
+       testrunner.cc
+
+testrunner_LDFLAGS = \
+       $(AM_LDFLAGS) \
+       $(PROGRAM_LDFLAGS) \
+       $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) \
+       -pthread
+
+testrunner_LDADD = \
+       $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) \
+       $(RT_LIBS)
+
index 026c9f4ba3ebfe1affc9a98c15d26dd887d5eaa9..9fcb3957db0911dd726a4c7713071493e211f00a 100644 (file)
@@ -11,6 +11,7 @@ PDNS_CHECK_READLINE([mandatory])
 PDNS_CHECK_CLOCK_GETTIME
 BOOST_REQUIRE([1.35])
 BOOST_FOREACH
+PDNS_ENABLE_UNIT_TESTS
 AC_SUBST([YAHTTP_CFLAGS], ['-I$(top_srcdir)/ext/yahttp'])
 AC_SUBST([YAHTTP_LIBS], ['-L$(top_builddir)/ext/yahttp/yahttp -lyahttp'])
 DNSDIST_LUA
diff --git a/pdns/dnsdistdist/dnsdist-ecs.cc b/pdns/dnsdistdist/dnsdist-ecs.cc
new file mode 120000 (symlink)
index 0000000..9bf0156
--- /dev/null
@@ -0,0 +1 @@
+../dnsdist-ecs.cc
\ No newline at end of file
diff --git a/pdns/dnsdistdist/dnsdist-ecs.hh b/pdns/dnsdistdist/dnsdist-ecs.hh
new file mode 120000 (symlink)
index 0000000..bbd2156
--- /dev/null
@@ -0,0 +1 @@
+../dnsdist-ecs.hh
\ No newline at end of file
diff --git a/pdns/dnsdistdist/dnsparser.cc b/pdns/dnsdistdist/dnsparser.cc
new file mode 120000 (symlink)
index 0000000..1352518
--- /dev/null
@@ -0,0 +1 @@
+../dnsparser.cc
\ No newline at end of file
diff --git a/pdns/dnsdistdist/ednssubnet.cc b/pdns/dnsdistdist/ednssubnet.cc
new file mode 120000 (symlink)
index 0000000..95d9f89
--- /dev/null
@@ -0,0 +1 @@
+../ednssubnet.cc
\ No newline at end of file
diff --git a/pdns/dnsdistdist/ednssubnet.hh b/pdns/dnsdistdist/ednssubnet.hh
new file mode 120000 (symlink)
index 0000000..43fefd7
--- /dev/null
@@ -0,0 +1 @@
+../ednssubnet.hh
\ No newline at end of file
diff --git a/pdns/dnsdistdist/m4/pdns_enable_unit_tests.m4 b/pdns/dnsdistdist/m4/pdns_enable_unit_tests.m4
new file mode 120000 (symlink)
index 0000000..469b730
--- /dev/null
@@ -0,0 +1 @@
+../../../m4/pdns_enable_unit_tests.m4
\ No newline at end of file
diff --git a/pdns/dnsdistdist/test-dnsdist_cc.cc b/pdns/dnsdistdist/test-dnsdist_cc.cc
new file mode 120000 (symlink)
index 0000000..ae06da2
--- /dev/null
@@ -0,0 +1 @@
+../test-dnsdist_cc.cc
\ No newline at end of file
diff --git a/pdns/dnsdistdist/testrunner.cc b/pdns/dnsdistdist/testrunner.cc
new file mode 100644 (file)
index 0000000..bce387f
--- /dev/null
@@ -0,0 +1,10 @@
+#define BOOST_TEST_DYN_LINK
+#define BOOST_TEST_MAIN
+#define BOOST_TEST_MODULE unit
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+#include <boost/test/unit_test.hpp>
+
+
index c9a4b7538776bba78a5cf441010b7bc053cd3255..5712b774b44009a28b8294522f92a3f5fa773bb0 100644 (file)
@@ -90,7 +90,7 @@ public:
   }
   bool matches(const ComboAddress& remote, const DNSName& qname, uint16_t qtype, dnsheader* dh, int len) const override
   {
-    return dh->cd || (getEDNSZ((const char*)dh, len) & 32768);    // turns out dig sets ad by default..
+    return dh->cd || (getEDNSZ((const char*)dh, len) & EDNS_HEADER_FLAG_DO);    // turns out dig sets ad by default..
   }
 
   string toString() const override
diff --git a/pdns/test-dnsdist_cc.cc b/pdns/test-dnsdist_cc.cc
new file mode 100644 (file)
index 0000000..fd2eddf
--- /dev/null
@@ -0,0 +1,265 @@
+
+/*
+    PowerDNS Versatile Database Driven Nameserver
+    Copyright (C) 2013 - 2015  PowerDNS.COM BV
+
+    This program is free software; you can redistribute it and/or modify
+    it under the terms of the GNU General Public License version 2
+    as published by the Free Software Foundation
+
+    Additionally, the license of this program contains a special
+    exception which allows to distribute the program in binary form when
+    it is linked against OpenSSL.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program; if not, write to the Free Software
+    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
+*/
+
+#define BOOST_TEST_DYN_LINK
+#define BOOST_TEST_NO_MAIN
+
+#include <boost/test/unit_test.hpp>
+
+#include "dnsdist.hh"
+#include "dnsdist-ecs.hh"
+#include "dolog.hh"
+#include "dnsname.hh"
+#include "dnsparser.hh"
+#include "dnswriter.hh"
+#include "ednssubnet.hh"
+#include <unistd.h>
+
+BOOST_AUTO_TEST_SUITE(dnsdist_cc)
+
+bool g_console{true};
+bool g_verbose{true};
+
+static void validateQuery(const char * packet, size_t packetSize)
+{
+  MOADNSParser mdp(packet, packetSize);
+  
+  BOOST_CHECK_EQUAL(mdp.d_qname.toString(), "www.powerdns.com.");
+
+  BOOST_CHECK_EQUAL(mdp.d_header.qdcount, 1);
+  BOOST_CHECK_EQUAL(mdp.d_header.ancount, 0);
+  BOOST_CHECK_EQUAL(mdp.d_header.nscount, 0);
+  BOOST_CHECK_EQUAL(mdp.d_header.arcount, 1);
+}
+
+BOOST_AUTO_TEST_CASE(addECSWithoutEDNS)
+{
+  string largerPacket;
+  bool ednsAdded = false;
+  ComboAddress remote;
+  DNSName name("www.powerdns.com.");
+
+  vector<uint8_t> query;
+  DNSPacketWriter pw(query, name, QType::OPT, QClass::IN, 0);
+  pw.getHeader()->rd = 1;
+  int len = query.size();
+    
+  /* large enough packet */
+  char packet[1500];
+  memcpy(packet, query.data(), query.size());
+
+  unsigned int consumed = 0;
+  uint16_t qtype;
+  DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::OPT);
+  
+  handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, remote);
+  BOOST_CHECK((size_t) len > query.size());
+  BOOST_CHECK_EQUAL(largerPacket.size(), 0);
+  BOOST_CHECK_EQUAL(ednsAdded, true);
+  validateQuery(packet, len);
+  
+  /* not large enought packet */
+  consumed = 0;
+  len = query.size();
+  qname = DNSName((char*) query.data(), len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::OPT);
+  
+  handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, remote);
+  BOOST_CHECK_EQUAL((size_t) len, query.size());
+  BOOST_CHECK(largerPacket.size() > query.size());
+  BOOST_CHECK_EQUAL(ednsAdded, true);
+  validateQuery(largerPacket.c_str(), largerPacket.size());
+}
+
+BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) {
+  string largerPacket;
+  bool ednsAdded = false;
+  ComboAddress remote;
+  DNSName name("www.powerdns.com");
+
+  vector<uint8_t> query;
+  DNSPacketWriter pw(query, name, QType::OPT, QClass::IN, 0);
+  pw.getHeader()->rd = 1;
+  pw.addOpt(512, 0, 0);
+  pw.commit();
+  int len = query.size();
+    
+  /* large enough packet */
+  char packet[1500];
+  memcpy(packet, query.data(), query.size());
+
+  unsigned int consumed = 0;
+  uint16_t qtype;
+  DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::OPT);
+  
+  handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, remote);
+  BOOST_CHECK((size_t) len > query.size());
+  BOOST_CHECK_EQUAL(largerPacket.size(), 0);
+  BOOST_CHECK_EQUAL(ednsAdded, false);
+  validateQuery(packet, len);
+  
+  /* not large enought packet */
+  consumed = 0;
+  len = query.size();
+  qname = DNSName((char*) query.data(), len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::OPT);
+  
+  handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, remote);
+  BOOST_CHECK_EQUAL((size_t) len, query.size());
+  BOOST_CHECK(largerPacket.size() > query.size());
+  BOOST_CHECK_EQUAL(ednsAdded, false);
+  validateQuery(largerPacket.c_str(), largerPacket.size());
+}
+
+BOOST_AUTO_TEST_CASE(replaceECSWithSameSize) {
+  string largerPacket;
+  bool ednsAdded = false;
+  ComboAddress remote("192.168.1.25");
+  DNSName name("www.powerdns.com");
+  ComboAddress origRemote("127.0.0.1");
+
+  vector<uint8_t> query;
+  DNSPacketWriter pw(query, name, QType::OPT, QClass::IN, 0);
+  pw.getHeader()->rd = 1;
+  EDNSSubnetOpts ecsOpts;
+  ecsOpts.source = Netmask(origRemote, g_ECSSourcePrefixV4);
+  string origECSOption = makeEDNSSubnetOptsString(ecsOpts);
+  DNSPacketWriter::optvect_t opts;
+  opts.push_back(make_pair(EDNS0_OPTION_CODE_ECS, origECSOption));
+  pw.addOpt(512, 0, 0, opts);
+  pw.commit();
+  int len = query.size();
+    
+  /* large enough packet */
+  char packet[1500];
+  memcpy(packet, query.data(), query.size());
+
+  unsigned int consumed = 0;
+  uint16_t qtype;
+  DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::OPT);
+
+  g_ECSOverride = true;
+  handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, remote);
+  BOOST_CHECK_EQUAL((size_t) len, query.size());
+  BOOST_CHECK_EQUAL(largerPacket.size(), 0);
+  BOOST_CHECK_EQUAL(ednsAdded, false);
+  validateQuery(packet, len);
+}
+
+BOOST_AUTO_TEST_CASE(replaceECSWithSmaller) {
+  string largerPacket;
+  bool ednsAdded = false;
+  ComboAddress remote("192.168.1.25");
+  DNSName name("www.powerdns.com");
+  ComboAddress origRemote("127.0.0.1");
+
+  vector<uint8_t> query;
+  DNSPacketWriter pw(query, name, QType::OPT, QClass::IN, 0);
+  pw.getHeader()->rd = 1;
+  EDNSSubnetOpts ecsOpts;
+  ecsOpts.source = Netmask(origRemote, 32);
+  string origECSOption = makeEDNSSubnetOptsString(ecsOpts);
+  DNSPacketWriter::optvect_t opts;
+  opts.push_back(make_pair(EDNS0_OPTION_CODE_ECS, origECSOption));
+  pw.addOpt(512, 0, 0, opts);
+  pw.commit();
+  int len = query.size();
+    
+  /* large enough packet */
+  char packet[1500];
+  memcpy(packet, query.data(), query.size());
+
+  unsigned int consumed = 0;
+  uint16_t qtype;
+  DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::OPT);
+
+  g_ECSOverride = true;
+  handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, remote);
+  BOOST_CHECK((size_t) len < query.size());
+  BOOST_CHECK_EQUAL(largerPacket.size(), 0);
+  BOOST_CHECK_EQUAL(ednsAdded, false);
+  validateQuery(packet, len);
+}
+
+BOOST_AUTO_TEST_CASE(replaceECSWithLarger) {
+  string largerPacket;
+  bool ednsAdded = false;
+  ComboAddress remote("192.168.1.25");
+  DNSName name("www.powerdns.com");
+  ComboAddress origRemote("127.0.0.1");
+
+  vector<uint8_t> query;
+  DNSPacketWriter pw(query, name, QType::OPT, QClass::IN, 0);
+  pw.getHeader()->rd = 1;
+  EDNSSubnetOpts ecsOpts;
+  ecsOpts.source = Netmask(origRemote, 8);
+  string origECSOption = makeEDNSSubnetOptsString(ecsOpts);
+  DNSPacketWriter::optvect_t opts;
+  opts.push_back(make_pair(EDNS0_OPTION_CODE_ECS, origECSOption));
+  pw.addOpt(512, 0, 0, opts);
+  pw.commit();
+  int len = query.size();
+    
+  /* large enough packet */
+  char packet[1500];
+  memcpy(packet, query.data(), query.size());
+
+  unsigned int consumed = 0;
+  uint16_t qtype;
+  DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::OPT);
+
+  g_ECSOverride = true;
+  handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, remote);
+  BOOST_CHECK((size_t) len > query.size());
+  BOOST_CHECK_EQUAL(largerPacket.size(), 0);
+  BOOST_CHECK_EQUAL(ednsAdded, false);
+  validateQuery(packet, len);
+
+  /* not large enought packet */
+  consumed = 0;
+  len = query.size();
+  qname = DNSName((char*) query.data(), len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::OPT);
+  
+  g_ECSOverride = true;
+  handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, remote);
+  BOOST_CHECK_EQUAL((size_t) len, query.size());
+  BOOST_CHECK(largerPacket.size() > query.size());
+  BOOST_CHECK_EQUAL(ednsAdded, false);
+  validateQuery(largerPacket.c_str(), largerPacket.size());
+}
+
+BOOST_AUTO_TEST_SUITE_END();
diff --git a/regression-tests.dnsdist/.gitignore b/regression-tests.dnsdist/.gitignore
new file mode 100644 (file)
index 0000000..1cc50b4
--- /dev/null
@@ -0,0 +1,7 @@
+/.venv
+/*.xml
+/*.pid
+/*.pyc
+dnsdist_ecs*.conf
+.dnsdist_history
+.history
diff --git a/regression-tests.dnsdist/clientsubnetoption.py b/regression-tests.dnsdist/clientsubnetoption.py
new file mode 100644 (file)
index 0000000..4097f4c
--- /dev/null
@@ -0,0 +1,301 @@
+#!/usr/bin/env python
+#
+# Copyright (c) 2012 OpenDNS, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#    * Redistributions of source code must retain the above copyright
+#      notice, this list of conditions and the following disclaimer.
+#    * Redistributions in binary form must reproduce the above copyright
+#      notice, this list of conditions and the following disclaimer in the
+#      documentation and/or other materials provided with the distribution.
+#    * Neither the name of the OpenDNS nor the names of its contributors may be
+#      used to endorse or promote products derived from this software without
+#      specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL OPENDNS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
+# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
+# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+""" Class to implement draft-ietf-dnsop-edns-client-subnet (previously known as
+draft-vandergaast-edns-client-subnet.
+
+The contained class supports both IPv4 and IPv6 addresses.
+Requirements:
+  dnspython (http://www.dnspython.org/)
+"""
+from __future__ import print_function
+from __future__ import division
+
+import socket
+import struct
+import dns
+import dns.edns
+import dns.flags
+import dns.message
+import dns.query
+
+__author__ = "bhartvigsen@opendns.com (Brian Hartvigsen)"
+__version__ = "2.0.0"
+
+ASSIGNED_OPTION_CODE = 0x0008
+DRAFT_OPTION_CODE = 0x50FA
+
+FAMILY_IPV4 = 1
+FAMILY_IPV6 = 2
+SUPPORTED_FAMILIES = (FAMILY_IPV4, FAMILY_IPV6)
+
+
+class ClientSubnetOption(dns.edns.Option):
+    """Implementation of draft-vandergaast-edns-client-subnet-01.
+
+    Attributes:
+        family: An integer inidicating which address family is being sent
+        ip: IP address in integer notation
+        mask: An integer representing the number of relevant bits being sent
+        scope: An integer representing the number of significant bits used by
+            the authoritative server.
+    """
+
+    def __init__(self, ip, bits=24, scope=0, option=ASSIGNED_OPTION_CODE):
+        super(ClientSubnetOption, self).__init__(option)
+
+        n = None
+        f = None
+
+        for family in (socket.AF_INET, socket.AF_INET6):
+            try:
+                n = socket.inet_pton(family, ip)
+                if family == socket.AF_INET6:
+                    f = FAMILY_IPV6
+                    hi, lo = struct.unpack('!QQ', n)
+                    ip = hi << 64 | lo
+                elif family == socket.AF_INET:
+                    f = FAMILY_IPV4
+                    ip = struct.unpack('!L', n)[0]
+            except Exception:
+                pass
+
+        if n is None:
+            raise Exception("%s is an invalid ip" % ip)
+
+        self.family = f
+        self.ip = ip
+        self.mask = bits
+        self.scope = scope
+        self.option = option
+
+        if self.family == FAMILY_IPV4 and self.mask > 32:
+            raise Exception("32 bits is the max for IPv4 (%d)" % bits)
+        if self.family == FAMILY_IPV6 and self.mask > 128:
+            raise Exception("128 bits is the max for IPv6 (%d)" % bits)
+
+    def calculate_ip(self):
+        """Calculates the relevant ip address based on the network mask.
+
+        Calculates the relevant bits of the IP address based on network mask.
+        Sizes up to the nearest octet for use with wire format.
+
+        Returns:
+            An integer of only the significant bits sized up to the nearest
+            octect.
+        """
+
+        if self.family == FAMILY_IPV4:
+            bits = 32
+        elif self.family == FAMILY_IPV6:
+            bits = 128
+
+        ip = self.ip >> bits - self.mask
+
+        if (self.mask % 8 != 0):
+            ip = ip << 8 - (self.mask % 8)
+
+        return ip
+
+    def is_draft(self):
+        """" Determines whether this instance is using the draft option code """
+        return self.option == DRAFT_OPTION_CODE
+
+    def to_wire(self, file):
+        """Create EDNS packet as definied in draft-vandergaast-edns-client-subnet-01."""
+
+        ip = self.calculate_ip()
+
+        mask_bits = self.mask
+        if mask_bits % 8 != 0:
+                mask_bits += 8 - (self.mask % 8)
+
+        if self.family == FAMILY_IPV4:
+            test = struct.pack("!L", ip)
+        elif self.family == FAMILY_IPV6:
+            test = struct.pack("!QQ", ip >> 64, ip & (2 ** 64 - 1))
+        test = test[-(mask_bits // 8):]
+
+        format = "!HBB%ds" % (mask_bits // 8)
+        data = struct.pack(format, self.family, self.mask, 0, test)
+        file.write(data)
+
+    def from_wire(cls, otype, wire, current, olen):
+        """Read EDNS packet as defined in draft-vandergaast-edns-client-subnet-01.
+
+        Returns:
+            An instance of ClientSubnetOption based on the ENDS packet
+        """
+
+        data = wire[current:current + olen]
+        (family, mask, scope) = struct.unpack("!HBB", data[:4])
+
+        c_mask = mask
+        if mask % 8 != 0:
+            c_mask += 8 - (mask % 8)
+
+        ip = struct.unpack_from("!%ds" % (c_mask // 8), data, 4)[0]
+
+        if (family == FAMILY_IPV4):
+            ip = ip + b'\0' * ((32 - c_mask) // 8)
+            ip = socket.inet_ntop(socket.AF_INET, ip)
+        elif (family == FAMILY_IPV6):
+            ip = ip + b'\0' * ((128 - c_mask) // 8)
+            ip = socket.inet_ntop(socket.AF_INET6, ip)
+        else:
+            raise Exception("Returned a family other then IPv4 or IPv6")
+
+        return cls(ip, mask, scope, otype)
+
+    from_wire = classmethod(from_wire)
+
+    def __repr__(self):
+        if self.family == FAMILY_IPV4:
+            ip = socket.inet_ntop(socket.AF_INET, struct.pack('!L', self.ip))
+        elif self.family == FAMILY_IPV6:
+            ip = socket.inet_ntop(socket.AF_INET6,
+                                  struct.pack('!QQ',
+                                              self.ip >> 64,
+                                              self.ip & (2 ** 64 - 1)))
+
+        return "%s(%s, %s, %s)" % (
+            self.__class__.__name__,
+            ip,
+            self.mask,
+            self.scope
+        )
+
+    def __eq__(self, other):
+        """Rich comparison method for equality.
+
+        Two ClientSubnetOptions are equal if their relevant ip bits, mask, and
+        family are identical. We ignore scope since generally we want to
+        compare questions to responses and that bit is only relevant when
+        determining caching behavior.
+
+        Returns:
+            boolean
+        """
+
+        if not isinstance(other, ClientSubnetOption):
+            return False
+        if self.calculate_ip() != other.calculate_ip():
+            return False
+        if self.mask != other.mask:
+            return False
+        if self.family != other.family:
+            return False
+        return True
+
+    def __ne__(self, other):
+        """Rich comparison method for inequality.
+
+        See notes for __eq__()
+
+        Returns:
+            boolean
+        """
+        return not self.__eq__(other)
+
+
+dns.edns._type_to_class[DRAFT_OPTION_CODE] = ClientSubnetOption
+dns.edns._type_to_class[ASSIGNED_OPTION_CODE] = ClientSubnetOption
+
+if __name__ == "__main__":
+    import argparse
+    import sys
+
+    def CheckForClientSubnetOption(addr, args, option_code=ASSIGNED_OPTION_CODE):
+        print("Testing for edns-clientsubnet using option code", hex(option_code), file=sys.stderr)
+        cso = ClientSubnetOption(args.subnet, args.mask, option=option_code)
+        message = dns.message.make_query(args.rr, args.type)
+        # Tested authoritative servers seem to use the last code in cases
+        # where they support both. We make the official code last to allow
+        # us to check for support of both draft and official
+        message.use_edns(options=[cso])
+
+        try:
+            r = dns.query.udp(message, addr, timeout=args.timeout)
+            if r.flags & dns.flags.TC:
+                r = dns.query.tcp(message, addr, timeout=args.timeout)
+        except dns.exception.Timeout:
+            print("Timeout: No answer received from %s\n" % args.nameserver, file=sys.stderr)
+            sys.exit(3)
+
+        error = False
+        found = False
+        for options in r.options:
+            # Have not run into anyone who passes back both codes yet
+            # but just in case, we want to check all possible options
+            if isinstance(options, ClientSubnetOption):
+                found = True
+                print("Found ClientSubnetOption...", end=None, file=sys.stderr)
+                if not cso.family == options.family:
+                    error = True
+                    print("\nFailed: returned family (%d) is different from the passed family (%d)" % (options.family, cso.family), file=sys.stderr)
+                if not cso.calculate_ip() == options.calculate_ip():
+                    error = True
+                    print("\nFailed: returned ip (%s) is different from the passed ip (%s)." % (options.calculate_ip(), cso.calculate_ip()), file=sys.stderr)
+                if not options.mask == cso.mask:
+                    error = True
+                    print("\nFailed: returned mask bits (%d) is different from the passed mask bits (%d)" % (options.mask, cso.mask), file=sys.stderr)
+                if not options.scope != 0:
+                    print("\nWarning: scope indicates edns-clientsubnet data is not used", file=sys.stderr)
+                if options.is_draft():
+                    print("\nWarning: detected support for edns-clientsubnet draft code", file=sys.stderr)
+
+        if found and not error:
+            print("Success", file=sys.stderr)
+        elif found:
+            print("Failed: See error messages above", file=sys.stderr)
+        else:
+            print("Failed: No ClientSubnetOption returned", file=sys.stderr)
+
+    parser = argparse.ArgumentParser(description='draft-vandergaast-edns-client-subnet-01 tester')
+    parser.add_argument('nameserver', help='The nameserver to test')
+    parser.add_argument('rr', help='DNS record that should return an EDNS enabled response')
+    parser.add_argument('-s', '--subnet', help='Specifies an IP to pass as the client subnet.', default='192.0.2.0')
+    parser.add_argument('-m', '--mask', type=int, help='CIDR mask to use for subnet')
+    parser.add_argument('--timeout', type=int, help='Set the timeout for query to TIMEOUT seconds, default=10', default=10)
+    parser.add_argument('-t', '--type', help='DNS query type, default=A', default='A')
+    args = parser.parse_args()
+
+    if not args.mask:
+        if ':' in args.subnet:
+            args.mask = 48
+        else:
+            args.mask = 24
+
+    try:
+        addr = socket.gethostbyname(args.nameserver)
+    except socket.gaierror:
+        print("Unable to resolve %s\n" % args.nameserver, file=sys.stderr)
+        sys.exit(3)
+
+    CheckForClientSubnetOption(addr, args, DRAFT_OPTION_CODE)
+    print("", file=sys.stderr)
+    CheckForClientSubnetOption(addr, args, ASSIGNED_OPTION_CODE)
diff --git a/regression-tests.dnsdist/dnsdist.conf b/regression-tests.dnsdist/dnsdist.conf
new file mode 100644 (file)
index 0000000..aa49752
--- /dev/null
@@ -0,0 +1,10 @@
+    truncateTC(true)
+    block=newDNSName("powerdns.org.")
+    function blockFilter(remote, qname, qtype, dh)
+        if(qname:isPartOf(block))
+        then
+            print("Blocking *.powerdns.org")
+            return true
+        end
+        return false
+    end
diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py
new file mode 100644 (file)
index 0000000..8b62a08
--- /dev/null
@@ -0,0 +1,205 @@
+#!/usr/bin/env python2
+
+import clientsubnetoption
+import dns
+import Queue
+import os
+import socket
+import struct
+import subprocess
+import sys
+import threading
+import time
+import unittest
+
+class DNSDistTest(unittest.TestCase):
+    """
+    Set up a dnsdist instance and responder threads.
+    Queries sent to dnsdist are relayed to the responder threads,
+    who reply with the response provided by the tests themselves
+    on a queue. Responder threads also queue the queries received
+    from dnsdist on a separate queue, allowing the tests to check
+    that the queries sent from dnsdist were as expected.
+    """
+    _dnsDistPort = 5340
+    _testServerPort = 5350
+    _dnsdistcmd = (os.environ['DNSDISTBIN'] + " -C dnsdist.conf --acl 127.0.0.1/32 -l 127.0.0.1:" + str(_dnsDistPort) + " 127.0.0.1:" + str(_testServerPort)).split()
+    _toResponderQueue = Queue.Queue()
+    _fromResponderQueue = Queue.Queue()
+    _dnsdist = None
+
+    @classmethod
+    def startResponders(cls):
+        print("Launching responders..")
+        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[])
+        cls._UDPResponder.setDaemon(True)
+        cls._UDPResponder.start()
+        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[])
+        cls._TCPResponder.setDaemon(True)
+        cls._TCPResponder.start()
+
+    @classmethod
+    def startDNSDist(cls, shutUp=True):
+        print("Launching dnsdist..")
+        print(' '.join(cls._dnsdistcmd))
+        if shutUp:
+            with open(os.devnull, 'w') as fdDevNull:
+                cls._dnsdist = subprocess.Popen(cls._dnsdistcmd, close_fds=True, stdout=fdDevNull, stderr=fdDevNull)
+        else:
+            cls._dnsdist = subprocess.Popen(cls._dnsdistcmd, close_fds=True)
+
+        time.sleep(1)
+
+        if cls._dnsdist.poll() is not None:
+            cls._dnsdist.terminate()
+            cls._dnsdist.wait()
+            sys.exit(cls._dnsdist.returncode)
+
+    @classmethod
+    def setUpSockets(cls):
+        print("Setting up UDP socket..")
+        cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        cls._sock.connect(("127.0.0.1", cls._dnsDistPort))
+
+    @classmethod
+    def setUpClass(cls):
+
+        cls.startResponders()
+        cls.startDNSDist()
+        cls.setUpSockets()
+
+        print("Launching tests..")
+
+    @classmethod
+    def tearDownClass(cls):
+        if cls._dnsdist:
+            cls._dnsdist.terminate()
+            cls._dnsdist.wait()
+
+    @classmethod
+    def UDPResponder(cls):
+        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+        sock.bind(("127.0.0.1", cls._testServerPort))
+        while True:
+            data, addr = sock.recvfrom(4096)
+            request = dns.message.from_wire(data)
+            if len(request.question) != 1:
+                print("Skipping query with question count %d" % (len(request.question)))
+                continue
+            if str(request.question[0].name).endswith('tests.powerdns.com.') and not cls._toResponderQueue.empty():
+                response = cls._toResponderQueue.get()
+                response.id = request.id
+                cls._fromResponderQueue.put(request)
+            else:
+                # unexpected query, or health check
+                response = dns.message.make_response(request)
+                rrset = dns.rrset.from_text(request.question[0].name,
+                                            3600,
+                                            request.question[0].rdclass,
+                                            request.question[0].rdtype,
+                                            '127.0.0.1')
+                response.answer.append(rrset)
+
+            sock.sendto(response.to_wire(), addr)
+        sock.close()
+
+    @classmethod
+    def TCPResponder(cls):
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+        try:
+            sock.bind(("127.0.0.1", cls._testServerPort))
+        except socket.error as e:
+            print("Error binding in the TCP responder: %s" % str(e))
+            sys.exit(1)
+
+        sock.listen(100)
+        while True:
+            (conn, address) = sock.accept()
+            data = conn.recv(2)
+            (datalen,) = struct.unpack("!H", data)
+            data = conn.recv(datalen)
+            request = dns.message.from_wire(data)
+            if len(request.question) != 1:
+                print("Skipping query with question count %d" % (len(request.question)))
+                continue
+            if str(request.question[0].name).endswith('tests.powerdns.com.') and not cls._toResponderQueue.empty():
+                response = cls._toResponderQueue.get()
+                response.id = request.id
+                cls._fromResponderQueue.put(request)
+            else:
+                # unexpected query, or health check
+                response = dns.message.make_response(request)
+                rrset = dns.rrset.from_text(request.question[0].name,
+                                            3600,
+                                            request.question[0].rdclass,
+                                            request.question[0].rdtype,
+                                            '127.0.0.1')
+                response.answer.append(rrset)
+
+            wire = response.to_wire()
+            conn.send(struct.pack("!H", len(wire)))
+            conn.send(wire)
+            conn.close()
+        sock.close()
+
+    @classmethod
+    def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0):
+        if useQueue:
+            cls._toResponderQueue.put(response)
+
+        if timeout:
+            cls._sock.settimeout(timeout)
+
+        try:
+            cls._sock.send(query.to_wire())
+            data = cls._sock.recv(4096)
+        except socket.timeout as e:
+            data = None
+        finally:
+            if timeout:
+                cls._sock.settimeout(None)
+
+        receivedQuery = None
+        message = None
+        if useQueue and not cls._fromResponderQueue.empty():
+            receivedQuery = cls._fromResponderQueue.get(query)
+        if data:
+            message = dns.message.from_wire(data)
+        return (receivedQuery, message)
+
+    @classmethod
+    def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0):
+        if useQueue:
+            cls._toResponderQueue.put(response)
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.connect(("127.0.0.1", cls._dnsDistPort))
+
+        if timeout:
+            sock.settimeout(timeout)
+
+        try:
+            wire = query.to_wire()
+            sock.send(struct.pack("!H", len(wire)))
+            sock.send(wire)
+            data = sock.recv(2)
+            if data:
+                (datalen,) = struct.unpack("!H", data)
+                data = sock.recv(datalen)
+        except socket.timeout as e:
+            print("Timeout: %s" % (str(e)))
+            data = None
+        except socket.error as e:
+            print("Network error: %s" % (str(e)))
+            data = None
+        finally:
+            sock.close()
+
+        receivedQuery = None
+        message = None
+        if useQueue and not cls._fromResponderQueue.empty():
+            receivedQuery = cls._fromResponderQueue.get(query)
+        if data:
+            message = dns.message.from_wire(data)
+        return (receivedQuery, message)
diff --git a/regression-tests.dnsdist/requirements.txt b/regression-tests.dnsdist/requirements.txt
new file mode 100644 (file)
index 0000000..5962a68
--- /dev/null
@@ -0,0 +1,2 @@
+dnspython>=1.11
+nose==1.3.0
diff --git a/regression-tests.dnsdist/runtests b/regression-tests.dnsdist/runtests
new file mode 100755 (executable)
index 0000000..273cd9c
--- /dev/null
@@ -0,0 +1,17 @@
+#!/usr/bin/env bash
+
+if [ ! -d .venv ]; then
+       virtualenv .venv
+fi
+. .venv/bin/activate
+python -V
+pip install -r requirements.txt
+
+if [ -z "${DNSDISTBIN}" ]; then
+    DNSDISTBIN="../pdns/dnsdistdist/dnsdist-*/dnsdist"
+fi
+export DNSDISTBIN
+
+set -e
+set -x
+nosetests --with-xunit
diff --git a/regression-tests.dnsdist/test_Basics.py b/regression-tests.dnsdist/test_Basics.py
new file mode 100644 (file)
index 0000000..ae0de1a
--- /dev/null
@@ -0,0 +1,84 @@
+#!/usr/bin/env python
+import clientsubnetoption
+import dns
+import unittest
+from dnsdisttests import DNSDistTest
+
+class TestBasics(DNSDistTest):
+
+    def testBlockedA(self):
+        """
+        Send an A query for the powerdns.org domain,
+        which is blocked by configuration. We expect
+        no response.
+        """
+        name = 'blockeda.tests.powerdns.org.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False, timeout=2.0)
+        self.assertEquals(receivedResponse, None)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False, timeout=2.0)
+        self.assertEquals(receivedResponse, None)
+
+    def testAWithECS(self):
+        """
+        Send an A query with an ECS value.
+        """
+        name = 'awithecs.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('1.2.3.4')
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso])
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+    def testSimpleA(self):
+        """
+        Send a simple A query without EDNS.
+        """
+        name = 'simplea.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text('simplea.tests.powerdns.com.',
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+
+if __name__ == '__main__':
+    unittest.main()
+    exit(0)
diff --git a/regression-tests.dnsdist/test_EdnsClientSubnet.py b/regression-tests.dnsdist/test_EdnsClientSubnet.py
new file mode 100644 (file)
index 0000000..c041252
--- /dev/null
@@ -0,0 +1,323 @@
+#!/usr/bin/env python
+import clientsubnetoption
+import dns
+import os
+import subprocess
+import time
+import unittest
+from dnsdisttests import DNSDistTest
+
+class TestEdnsClientSubnetNoOverride(DNSDistTest):
+    """
+    DNSdist is configured to add the EDNS0 Client Subnet
+    option, but only if it's not already present in the
+    original query.
+    """
+
+    _dnsDistPort = 5340
+    _config_template = """
+    truncateTC(true)
+    block=newDNSName("powerdns.org.")
+    function blockFilter(remote, qname, qtype, dh)
+        if(qname:isPartOf(block))
+        then
+            print("Blocking *.powerdns.org")
+            return true
+        end
+        return false
+    end
+    newServer{address="127.0.0.1:%s", useClientSubnet=true}
+    """
+
+    _dnsdistcmd = (os.environ['DNSDISTBIN'] + " -C dnsdist_ecs_no_override.conf --acl 127.0.0.1/32 -l 127.0.0.1:" + str(_dnsDistPort)).split()
+
+    @classmethod
+    def startDNSDist(cls, shutUp=True):
+        print("Launching dnsdist..")
+        with open('dnsdist_ecs_no_override.conf', 'w') as conf:
+            conf.write(cls._config_template % str(cls._testServerPort))
+
+        print(' '.join(cls._dnsdistcmd))
+        if shutUp:
+            with open(os.devnull, 'w') as fdDevNull:
+                cls._dnsdist = subprocess.Popen(cls._dnsdistcmd, close_fds=True, stdout=fdDevNull, stderr=fdDevNull)
+        else:
+            cls._dnsdist = subprocess.Popen(cls._dnsdistcmd, close_fds=True)
+
+        time.sleep(1)
+
+        if cls._dnsdist.poll() is not None:
+            cls._dnsdist.terminate()
+            cls._dnsdist.wait()
+            sys.exit(cls._dnsdist.returncode)
+
+    def testWithoutEDNS(self):
+        """
+        Send a query without EDNS, check that the query
+        received by the responder has the correct ECS value
+        and that the response received from dnsdist does not
+        have an EDNS pseudo-RR.
+        """
+        name = 'withoutedns.ecs.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512)
+        response = dns.message.make_response(expectedQuery)
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+        expectedResponse.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        receivedResponse.id = expectedResponse.id
+        self.assertEquals(expectedQuery, receivedQuery)
+        self.assertEquals(expectedResponse, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        receivedResponse.id = expectedResponse.id
+        self.assertEquals(expectedQuery, receivedQuery)
+        self.assertEquals(expectedResponse, receivedResponse)
+
+    def testWithEDNSNoECS(self):
+        """
+        Send a query with EDNS but no ECS value.
+        Check that the query received by the responder
+        has a valid ECS value and that the response
+        received from dnsdist contains an EDNS pseudo-RR.
+        """
+        name = 'withednsnoecs.ecs.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096)
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso])
+        response = dns.message.make_response(expectedQuery)
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+        expectedResponse.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        receivedResponse.id = expectedResponse.id
+        self.assertEquals(expectedQuery, receivedQuery)
+        self.assertEquals(expectedResponse, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        receivedResponse.id = expectedResponse.id
+        self.assertEquals(expectedQuery, receivedQuery)
+        self.assertEquals(expectedResponse, receivedResponse)
+
+    def testWithEDNSECS(self):
+        """
+        Send a query with EDNS and a crafted ECS value.
+        Check that the query received by the responder
+        has the initial ECS value (not overwritten)
+        and that the response received from dnsdist contains
+        an EDNS pseudo-RR.
+        """
+        name = 'withednsecs.ecs.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('1.2.3.4', 24)
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso])
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+class TestEdnsClientSubnetOverride(DNSDistTest):
+    """
+    DNSdist is configured to add the EDNS0 Client Subnet
+    option, overwriting any existing value.
+    """
+
+    _dnsDistPort = 5340
+    _config_template = """
+    truncateTC(true)
+    block=newDNSName("powerdns.org.")
+    function blockFilter(remote, qname, qtype, dh)
+        if(qname:isPartOf(block))
+        then
+            print("Blocking *.powerdns.org")
+            return true
+        end
+        return false
+    end
+    setECSOverride(true)
+    setECSSourcePrefixV4(24)
+    setECSSourcePrefixV6(56)
+    newServer{address="127.0.0.1:%s", useClientSubnet=true}
+    """
+
+    _dnsdistcmd = (os.environ['DNSDISTBIN'] + " -C dnsdist_ecs_override.conf --acl 127.0.0.1/32 -l 127.0.0.1:" + str(_dnsDistPort)).split()
+
+    @classmethod
+    def startDNSDist(cls, shutUp=True):
+        print("Launching dnsdist..")
+        with open('dnsdist_ecs_override.conf', 'w') as conf:
+            conf.write(cls._config_template % str(cls._testServerPort))
+
+        print(' '.join(cls._dnsdistcmd))
+        if shutUp:
+            with open(os.devnull, 'w') as fdDevNull:
+                cls._dnsdist = subprocess.Popen(cls._dnsdistcmd, close_fds=True, stdout=fdDevNull, stderr=fdDevNull)
+        else:
+            cls._dnsdist = subprocess.Popen(cls._dnsdistcmd, close_fds=True)
+
+        time.sleep(1)
+
+        if cls._dnsdist.poll() is not None:
+            cls._dnsdist.terminate()
+            cls._dnsdist.wait()
+            sys.exit(cls._dnsdist.returncode)
+
+    def testWithoutEDNS(self):
+        """
+        Send a query without EDNS, check that the query
+        received by the responder has the correct ECS value
+        and that the response received from dnsdist does not
+        have an EDNS pseudo-RR.
+        """
+        name = 'withoutedns.overriden.ecs.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512)
+        response = dns.message.make_response(expectedQuery)
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+        expectedResponse.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        receivedResponse.id = expectedResponse.id
+        self.assertEquals(expectedQuery, receivedQuery)
+        self.assertEquals(expectedResponse, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        receivedResponse.id = expectedResponse.id
+        self.assertEquals(expectedQuery, receivedQuery)
+        self.assertEquals(expectedResponse, receivedResponse)
+
+    def testWithEDNSNoECS(self):
+        """
+        Send a query with EDNS but no ECS value.
+        Check that the query received by the responder
+        has a valid ECS value and that the response
+        received from dnsdist contains an EDNS pseudo-RR.
+        """
+        name = 'withednsnoecs.overriden.ecs.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096)
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso])
+        response = dns.message.make_response(expectedQuery)
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+        expectedResponse.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        receivedResponse.id = expectedResponse.id
+        self.assertEquals(expectedQuery, receivedQuery)
+        self.assertEquals(expectedResponse, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        receivedResponse.id = expectedResponse.id
+        self.assertEquals(expectedQuery, receivedQuery)
+        self.assertEquals(expectedResponse, receivedResponse)
+
+    def testWithEDNSECS(self):
+        """
+        Send a query with EDNS and a crafted ECS value.
+        Check that the query received by the responder
+        has an overwritten ECS value (not the initial one)
+        and that the response received from dnsdist contains
+        an EDNS pseudo-RR.
+        """
+        name = 'withednsecs.overriden.ecs.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('1.2.3.4', 24)
+        rewrittenEcso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso])
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[rewrittenEcso])
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        receivedResponse.id = response.id
+        self.assertEquals(expectedQuery, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        receivedResponse.id = response.id
+        self.assertEquals(expectedQuery, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+
+if __name__ == '__main__':
+    unittest.main()
+    exit(0)