]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Set a correct EDNS OPT RR for self-generated answers
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 17 Aug 2018 15:59:54 +0000 (17:59 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 3 Sep 2018 09:11:45 +0000 (11:11 +0200)
19 files changed:
pdns/dnsdist-console.cc
pdns/dnsdist-ecs.cc
pdns/dnsdist-ecs.hh
pdns/dnsdist-lua-actions.cc
pdns/dnsdist-lua-bindings-dnsquestion.cc
pdns/dnsdist-lua-rules.cc
pdns/dnsdist-lua.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/dnsdist-rules.hh
pdns/dnsdistdist/docs/reference/config.rst
pdns/dnsdistdist/test-dnsdistrules_cc.cc
pdns/dnsparser.cc
pdns/dnsparser.hh
pdns/test-dnsdist_cc.cc
pdns/test-dnsdistpacketcache_cc.cc
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_EDNSSelfGenerated.py [new file with mode: 0644]

index d6f5b79f085537c076200085e6cb6817e077ea8b..2bbfb04374d66792e1db93923146e46a2032c471 100644 (file)
@@ -344,8 +344,10 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "DropResponseAction", true, "", "drop these packets" },
   { "DSTPortRule", true, "port", "matches questions received to the destination port specified" },
   { "dumpStats", true, "", "print all statistics we gather" },
-  { "exceedNXDOMAINs", true, "rate, seconds", "get set of addresses that exceed `rate` NXDOMAIN/s over `seconds` seconds" },
   { "dynBlockRulesGroup", true, "", "return a new DynBlockRulesGroup object" },
+  { "EDNSOptionRule", true, "optcode", "matches queries with the specified EDNS0 option present" },
+  { "ERCodeRule", true, "rcode", "matches responses with the specified extended rcode (EDNS0)" },
+  { "exceedNXDOMAINs", true, "rate, seconds", "get set of addresses that exceed `rate` NXDOMAIN/s over `seconds` seconds" },
   { "exceedQRate", true, "rate, seconds", "get set of address that exceed `rate` queries/s over `seconds` seconds" },
   { "exceedQTypeRate", true, "type, rate, seconds", "get set of address that exceed `rate` queries/s for queries of type `type` over `seconds` seconds" },
   { "exceedRespByterate", true, "rate, seconds", "get set of addresses that exceeded `rate` bytes/s answers over `seconds` seconds" },
@@ -388,6 +390,11 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "NoRecurseAction", true, "", "strip RD bit from the question, let it go through" },
   { "PoolAction", true, "poolname", "set the packet into the specified pool" },
   { "printDNSCryptProviderFingerprint", true, "\"/path/to/providerPublic.key\"", "display the fingerprint of the provided resolver public key" },
+  { "QNameLabelsCountRule", true, "min, max", "matches if the qname has less than `min` or more than `max` labels" },
+  { "QNameRule", true, "qname", "matches queries with the specified qname" },
+  { "QNameWireLengthRule", true, "min, max", "matches if the qname's length on the wire is less than `min` or more than `max` bytes" },
+  { "QTypeRule", true, "qtype", "matches queries with the specified qtype" },
+  { "RCodeRule", true, "rcode", "matches responses with the specified rcode" },
   { "RegexRule", true, "regex", "matches the query name against the supplied regex" },
   { "registerDynBPFFilter", true, "DynBPFFilter", "register this dynamic BPF filter into the web interface so that its counters are displayed" },
   { "RemoteLogAction", true, "RemoteLogger [, alterFunction]", "send the content of this query to a remote logger via Protocol Buffer. `alterFunction` is a callback, receiving a DNSQuestion and a DNSDistProtoBufMessage, that can be used to modify the Protocol Buffer content, for example for anonymization purposes" },
@@ -398,15 +405,9 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "rmSelfAnsweredResponseRule", true, "id", "remove self-answered response rule in position 'id', or whose uuid matches if 'id' is an UUID string" },
   { "rmServer", true, "n", "remove server with index n" },
   { "roundrobin", false, "", "Simple round robin over available servers" },
-  { "QNameLabelsCountRule", true, "min, max", "matches if the qname has less than `min` or more than `max` labels" },
-  { "QNameRule", true, "qname", "matches queries with the specified qname" },
-  { "QNameWireLengthRule", true, "min, max", "matches if the qname's length on the wire is less than `min` or more than `max` bytes" },
-  { "QTypeRule", true, "qtype", "matches queries with the specified qtype" },
-  { "RCodeRule", true, "rcode", "matches responses with the specified rcode" },
-  { "ERCodeRule", true, "rcode", "matches responses with the specified extended rcode (EDNS0)" },
-  { "EDNSOptionRule", true, "optcode", "matches queries with the specified EDNS0 option present" },
   { "sendCustomTrap", true, "str", "send a custom `SNMP` trap from Lua, containing the `str` string"},
   { "setACL", true, "{netmask, netmask}", "replace the ACL set with these netmasks. Use `setACL({})` to reset the list, meaning no one can use us" },
+  { "setAddEDNSToSelfGeneratedResponses", true, "add", "set whether to add EDNS to self-generated responses, provided that the initial query had EDNS" },
   { "setAPIWritable", true, "bool, dir", "allow modifications via the API. if `dir` is set, it must be a valid directory where the configuration files will be written by the API" },
   { "setConsoleACL", true, "{netmask, netmask}", "replace the console ACL set with these netmasks" },
   { "setConsoleConnectionsLogging", true, "enabled", "whether to log the opening and closing of console connections" },
@@ -423,6 +424,7 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "setMaxTCPQueriesPerConnection", true, "n", "set the maximum number of queries in an incoming TCP connection. 0 means unlimited" },
   { "setMaxTCPQueuedConnections", true, "n", "set the maximum number of TCP connections queued (waiting to be picked up by a client thread)" },
   { "setMaxUDPOutstanding", true, "n", "set the maximum number of outstanding UDP queries to a given backend server. This can only be set at configuration time and defaults to 10240" },
+  { "setPayloadSizeOnSelfGeneratedAnswers", true, "add", "set the UDP payload size advertised via EDNS on self-generated responses" },
   { "setPoolServerPolicy", true, "policy, pool", "set the server selection policy for this pool to that policy" },
   { "setPoolServerPolicy", true, "name, func, pool", "set the server selection policy for this pool to one named 'name' and provided by 'function'" },
   { "setQueryCount", true, "bool", "set whether queries should be counted" },
index 1d35a84288cab231639fbe1afa1ad50d3e48677e..cf8b73e8a6a8367c536e1053f0a51f1347236551 100644 (file)
 /* when we add EDNS to a query, we don't want to advertise
    a large buffer size */
 size_t g_EdnsUDPPayloadSize = 512;
+uint16_t g_PayloadSizeSelfGenAnswers{s_udpIncomingBufferSize};
+
 /* draft-ietf-dnsop-edns-client-subnet-04 "11.1.  Privacy" */
 uint16_t g_ECSSourcePrefixV4 = 24;
 uint16_t g_ECSSourcePrefixV6 = 56;
 
 bool g_ECSOverride{false};
+bool g_addEDNSToSelfGeneratedResponses{true};
 
 int rewriteResponseWithoutEDNS(const std::string& initialPacket, vector<uint8_t>& newContent)
 {
@@ -241,20 +244,21 @@ static void generateECSOption(const ComboAddress& source, string& res, uint16_t
   generateEDNSOption(EDNSOptionCode::ECS, payload, res);
 }
 
-void generateOptRR(const std::string& optRData, string& res)
+void generateOptRR(const std::string& optRData, string& res, uint16_t udpPayloadSize, bool dnssecOK)
 {
   const uint8_t name = 0;
   dnsrecordheader dh;
   EDNS0Record edns0;
   edns0.extRCode = 0;
   edns0.version = 0;
-  edns0.extFlags = 0;
-  
+  edns0.extFlags = dnssecOK ? htons(EDNS_HEADER_FLAG_DO) : 0;
+
   dh.d_type = htons(QType::OPT);
-  dh.d_class = htons(g_EdnsUDPPayloadSize);
+  dh.d_class = htons(udpPayloadSize);
   static_assert(sizeof(EDNS0Record) == sizeof(dh.d_ttl), "sizeof(EDNS0Record) must match sizeof(dnsrecordheader.d_ttl)");
   memcpy(&dh.d_ttl, &edns0, sizeof edns0);
   dh.d_clen = htons((uint16_t) optRData.length());
+  res.reserve(sizeof(name) + sizeof(dh) + optRData.length());
   res.assign((const char *) &name, sizeof name);
   res.append((const char *) &dh, sizeof dh);
   res.append(optRData.c_str(), optRData.length());
@@ -352,7 +356,7 @@ bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const u
     struct dnsheader* dh = (struct dnsheader*) packet;
     string optRData;
     generateECSOption(remote, optRData, ecsPrefixLength);
-    generateOptRR(optRData, EDNSRR);
+    generateOptRR(optRData, EDNSRR, g_EdnsUDPPayloadSize, false);
 
     /* does it fit in the existing buffer? */
     if (packetSize - *len <= EDNSRR.size()) {
@@ -554,3 +558,119 @@ int rewriteResponseWithoutEDNSOption(const std::string& initialPacket, const uin
 
   return 0;
 }
+
+bool addEDNS(DNSQuestion& dq, bool dnssecOK)
+{
+  if (dq.dh->arcount != 0) {
+    return false;
+  }
+
+  std::string optRecord;
+  generateOptRR(std::string(), optRecord, g_PayloadSizeSelfGenAnswers, dnssecOK);
+
+  if (optRecord.size() >= dq.size || (dq.size - optRecord.size()) < dq.len) {
+    return false;
+  }
+
+  char * optPtr = reinterpret_cast<char*>(dq.dh) + dq.len;
+  memcpy(optPtr, optRecord.data(), optRecord.size());
+  dq.len += optRecord.size();
+  dq.dh->arcount = htons(1);
+
+  return true;
+}
+
+bool addEDNSToQueryTurnedResponse(DNSQuestion& dq)
+{
+  char* optRDLen = nullptr;
+  /* remaining is at least the size of the rdlen + the options if any + the following records if any */
+  size_t remaining = 0;
+
+  int res = getEDNSOptionsStart(reinterpret_cast<char*>(dq.dh), dq.consumed, dq.len, &optRDLen, &remaining);
+
+  if (res != 0) {
+    /* if the initial query did not have EDNS0, we are done */
+    return true;
+  }
+
+  const size_t existingOptLen = /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + /* Z */ 2 + remaining;
+  if (existingOptLen >= dq.len) {
+    /* something is wrong, bail out */
+    return false;
+  }
+
+  char * optPtr = (optRDLen - (/* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + /* Z */ 2));
+
+  const uint8_t* zPtr = (const uint8_t*) optPtr + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE;
+  uint16_t z = 0x100 * (*zPtr) + *(zPtr + 1);
+  bool dnssecOK = z & EDNS_HEADER_FLAG_DO;
+
+  /* remove the existing OPT record, and everything else that follows (any SIG or TSIG would be useless anyway) */
+  dq.len -= existingOptLen;
+  dq.dh->arcount = 0;
+
+  if (g_addEDNSToSelfGeneratedResponses) {
+    /* now we need to add a new OPT record */
+    return addEDNS(dq, dnssecOK);
+  }
+
+  /* otherwise we are just fine */
+  return true;
+}
+
+// goal in life - if you send us a reasonably normal packet, we'll get Z for you, otherwise 0
+int getEDNSZ(const DNSQuestion& dq)
+try
+{
+  if (ntohs(dq.dh->qdcount) != 1 || dq.dh->ancount != 0 || ntohs(dq.dh->arcount) != 1 || dq.dh->nscount != 0) {
+    return 0;
+  }
+
+  if (dq.len <= sizeof(dnsheader)) {
+    return 0;
+  }
+
+  size_t pos = sizeof(dnsheader) + dq.consumed + DNS_TYPE_SIZE + DNS_CLASS_SIZE;
+
+  if (dq.len <= (pos + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE)) {
+    return 0;
+  }
+
+  const char* packet = reinterpret_cast<const char*>(dq.dh);
+
+  if (packet[pos] != 0) {
+    /* not root, so not a valid OPT record */
+    return 0;
+  }
+
+  pos++;
+
+  uint16_t qtype = (const unsigned char)packet[pos]*256 + (const unsigned char)packet[pos+1];
+  pos += DNS_TYPE_SIZE;
+  pos += DNS_CLASS_SIZE;
+
+  if (qtype != QType::OPT || (pos + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + 1) >= dq.len) {
+    return 0;
+  }
+
+  const uint8_t* z = (const uint8_t*) packet + pos + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE;
+  return 0x100 * (*z) + *(z+1);
+}
+catch(...)
+{
+  return 0;
+}
+
+bool queryHasEDNS(const DNSQuestion& dq)
+{
+  char * optRDLen = nullptr;
+  size_t ecsRemaining = 0;
+
+  int res = getEDNSOptionsStart(reinterpret_cast<char*>(dq.dh), dq.consumed, dq.len, &optRDLen, &ecsRemaining);
+  if (res == 0) {
+    return true;
+  }
+
+  return false;
+}
+
index fad0b77c495b05872d8a046eb1d733139a0a9b5c..e0b92a42c625aff47b8e27ce930eb2e9b450df14 100644 (file)
  */
 #pragma once
 
+extern size_t g_EdnsUDPPayloadSize;
+extern uint16_t g_PayloadSizeSelfGenAnswers;
+
 int rewriteResponseWithoutEDNS(const std::string& initialPacket, vector<uint8_t>& newContent);
 int locateEDNSOptRR(const std::string& packet, uint16_t * optStart, size_t * optLen, bool * last);
 bool handleEDNSClientSubnet(char * packet, size_t packetSize, unsigned int consumed, uint16_t * len, bool* ednsAdded, bool* ecsAdded, const ComboAddress& remote, bool overrideExisting, uint16_t ecsPrefixLength);
-void generateOptRR(const std::string& optRData, string& res);
+void generateOptRR(const std::string& optRData, string& res, uint16_t udpPayloadSize, bool dnssecOK);
 int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optionCodeToRemove);
 int rewriteResponseWithoutEDNSOption(const std::string& initialPacket, const uint16_t optionCodeToSkip, vector<uint8_t>& newContent);
 int getEDNSOptionsStart(char* packet, const size_t offset, const size_t len, char ** optRDLen, size_t * remaining);
 bool isEDNSOptionInOpt(const std::string& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind);
+bool addEDNS(DNSQuestion& dq, bool dnssecOK);
+bool addEDNSToQueryTurnedResponse(DNSQuestion& dq);
+
+int getEDNSZ(const DNSQuestion& dq);
+bool queryHasEDNS(const DNSQuestion& dq);
+
index 3df44d45c46afe55a01973d51ecc5c8585ac5c49..7724fda54c10d63546cffae56b2d017866d2cf5a 100644 (file)
@@ -402,6 +402,13 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, string* ruleresult) c
     return Action::None;
   }
 
+  bool dnssecOK = false;
+  bool hadEDNS = false;
+  if (g_addEDNSToSelfGeneratedResponses && queryHasEDNS(*dq)) {
+    hadEDNS = true;
+    dnssecOK = getEDNSZ(*dq) & EDNS_HEADER_FLAG_DO;
+  }
+
   dq->len = sizeof(dnsheader) + consumed + 4; // there goes your EDNS
   char* dest = ((char*)dq->dh) + dq->len;
 
@@ -450,6 +457,10 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, string* ruleresult) c
 
   dq->dh->ancount = htons(dq->dh->ancount);
 
+  if (hadEDNS && g_addEDNSToSelfGeneratedResponses) {
+    addEDNS(*dq, dnssecOK);
+  }
+
   return Action::HeaderModify;
 }
 
@@ -471,7 +482,7 @@ public:
     generateEDNSOption(d_code, mac, optRData);
 
     string res;
-    generateOptRR(optRData, res);
+    generateOptRR(optRData, res, g_EdnsUDPPayloadSize, false);
 
     if ((dq->size - dq->len) < res.length())
       return Action::None;
index 9c033bbae6069eef123884ebe4ea7b473a8426bb..abb3a4fc2fc7885cbdd00f2469931e7e248fca31 100644 (file)
@@ -20,6 +20,7 @@
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
  */
 #include "dnsdist.hh"
+#include "dnsdist-ecs.hh"
 #include "dnsdist-lua.hh"
 #include "dnsparser.hh"
 
@@ -52,7 +53,7 @@ void setupLuaBindingsDNSQuestion()
       }
     );
   g_lua.registerFunction<bool(DNSQuestion::*)()>("getDO", [](const DNSQuestion& dq) {
-      return getEDNSZ((const char*)dq.dh, dq.len) & EDNS_HEADER_FLAG_DO;
+      return getEDNSZ(dq) & EDNS_HEADER_FLAG_DO;
     });
   g_lua.registerFunction<void(DNSQuestion::*)(std::string)>("sendTrap", [](const DNSQuestion& dq, boost::optional<std::string> reason) {
 #ifdef HAVE_NET_SNMP
index bb501d2b3b76ee5de7d2411e4b2281577724139e..973047ae7fc89ee60427ea9f67af1f48a4b011fd 100644 (file)
@@ -322,7 +322,7 @@ void setupLuaRules()
       sw.start();
       for(int n=0; n < times; ++n) {
         const item& i = items[n % items.size()];
-        DNSQuestion dq(&i.qname, i.qtype, i.qclass, &i.rem, &i.rem, (struct dnsheader*)&i.packet[0], i.packet.size(), i.packet.size(), false, &sw.d_start);
+        DNSQuestion dq(&i.qname, i.qtype, i.qclass, 0, &i.rem, &i.rem, (struct dnsheader*)&i.packet[0], i.packet.size(), i.packet.size(), false, &sw.d_start);
         if(rule->matches(&dq))
           matches++;
       }
index 457b7fe27c14ba767fe826c9bb9216d2304e1704..ce7e6dafe434f5d4c0a9e50f30784d9f4c26bcce 100644 (file)
@@ -33,6 +33,7 @@
 
 #include "dnsdist.hh"
 #include "dnsdist-console.hh"
+#include "dnsdist-ecs.hh"
 #include "dnsdist-lua.hh"
 #include "dnsdist-rings.hh"
 
@@ -1480,6 +1481,24 @@ void setupLuaConfig(bool client)
 #endif
     });
 
+  g_lua.writeFunction("setAddEDNSToSelfGeneratedResponses", [](bool add) {
+      g_addEDNSToSelfGeneratedResponses = add;
+  });
+
+  g_lua.writeFunction("setPayloadSizeOnSelfGeneratedAnswers", [](uint16_t payloadSize) {
+      if (payloadSize < 512) {
+        warnlog("setPayloadSizeOnSelfGeneratedAnswers() is set too low, using 512 instead!");
+        g_outputBuffer="setPayloadSizeOnSelfGeneratedAnswers() is set too low, using 512 instead!";
+        payloadSize = 512;
+      }
+      if (payloadSize > s_udpIncomingBufferSize) {
+        warnlog("setPayloadSizeOnSelfGeneratedAnswers() is set too high, capping to %d instead!", s_udpIncomingBufferSize);
+        g_outputBuffer="setPayloadSizeOnSelfGeneratedAnswers() is set too high, capping to " + std::to_string(s_udpIncomingBufferSize) + " instead";
+        payloadSize = s_udpIncomingBufferSize;
+      }
+      g_PayloadSizeSelfGenAnswers = payloadSize;
+  });
+
   g_lua.writeFunction("addTLSLocal", [client](const std::string& addr, boost::variant<std::string, std::vector<std::pair<int,std::string>>> certFiles, boost::variant<std::string, std::vector<std::pair<int,std::string>>> keyFiles, boost::optional<localbind_t> vars) {
         if (client)
           return;
index 1bc7be143cce0027902a7c8e1d0f9ed961009949..5ca8ba071777e332f2440c8f579b2a70bc09730f 100644 (file)
@@ -358,16 +358,16 @@ void* tcpClientThread(int pipefd)
        uint16_t qtype, qclass;
        unsigned int consumed = 0;
        DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
-       DNSQuestion dq(&qname, qtype, qclass, &dest, &ci.remote, dh, queryBuffer.size(), qlen, true, &queryRealTime);
+       DNSQuestion dq(&qname, qtype, qclass, consumed, &dest, &ci.remote, dh, queryBuffer.size(), qlen, true, &queryRealTime);
 
        if (!processQuery(holders, dq, poolname, &delayMsec, now)) {
          goto drop;
        }
 
        if(dq.dh->qr) { // something turned it into a response
-          restoreFlags(dh, origFlags);
+          fixUpQueryTurnedResponse(dq, origFlags);
 
-          DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.local, dq.remote, reinterpret_cast<dnsheader*>(query), dq.size, dq.len, true, &queryRealTime);
+          DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(query), dq.size, dq.len, true, &queryRealTime);
 #ifdef HAVE_PROTOBUF
           dr.uniqueId = dq.uniqueId;
 #endif
@@ -419,7 +419,7 @@ void* tcpClientThread(int pipefd)
           uint16_t cachedResponseSize = sizeof cachedResponse;
           uint32_t allowExpired = ds ? 0 : g_staleCacheEntriesTTL;
           if (packetCache->get(dq, (uint16_t) consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey, subnet, allowExpired)) {
-            DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.local, dq.remote, (dnsheader*) cachedResponse, sizeof cachedResponse, cachedResponseSize, true, &queryRealTime);
+            DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, (dnsheader*) cachedResponse, sizeof cachedResponse, cachedResponseSize, true, &queryRealTime);
 #ifdef HAVE_PROTOBUF
             dr.uniqueId = dq.uniqueId;
 #endif
@@ -449,7 +449,7 @@ void* tcpClientThread(int pipefd)
             dq.dh->rcode = RCode::ServFail;
             dq.dh->qr = true;
 
-            DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.local, dq.remote, reinterpret_cast<dnsheader*>(query), dq.size, dq.len, false, &queryRealTime);
+            DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(query), dq.size, dq.len, false, &queryRealTime);
 #ifdef HAVE_PROTOBUF
             dr.uniqueId = dq.uniqueId;
 #endif
@@ -581,7 +581,8 @@ void* tcpClientThread(int pipefd)
           break;
         }
 
-        if (firstPacket && !responseContentMatches(response, responseLen, qname, qtype, qclass, ds->remote)) {
+        consumed = 0;
+        if (firstPacket && !responseContentMatches(response, responseLen, qname, qtype, qclass, ds->remote, consumed)) {
           break;
         }
         firstPacket=false;
@@ -590,7 +591,7 @@ void* tcpClientThread(int pipefd)
         }
 
         dh = (struct dnsheader*) response;
-        DNSResponse dr(&qname, qtype, qclass, &dest, &ci.remote, dh, responseSize, responseLen, true, &queryRealTime);
+        DNSResponse dr(&qname, qtype, qclass, consumed, &dest, &ci.remote, dh, responseSize, responseLen, true, &queryRealTime);
 #ifdef HAVE_PROTOBUF
         dr.uniqueId = dq.uniqueId;
 #endif
index b0c17bde51fe4555e9d9a154910f4f8744ade2b5..8abb56e41bf3ec287aad8da51bfcf6f4dbd7a440 100644 (file)
@@ -142,16 +142,12 @@ bool g_servFailOnNoPolicy{false};
 bool g_truncateTC{false};
 bool g_fixupCase{0};
 
-static const size_t s_udpIncomingBufferSize{1500};
-
-static void truncateTC(const char* packet, uint16_t* len)
+static void truncateTC(char* packet, uint16_t* len, unsigned int consumed)
 try
 {
-  unsigned int consumed;
-  DNSName qname(packet, *len, sizeof(dnsheader), false, 0, 0, &consumed);
   *len=(uint16_t) (sizeof(dnsheader)+consumed+DNS_TYPE_SIZE+DNS_CLASS_SIZE);
-  struct dnsheader* dh =(struct dnsheader*)packet;
-  dh->ancount = dh->arcount = dh->nscount=0;
+  struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(packet);
+  dh->ancount = dh->arcount = dh->nscount = 0;
 }
 catch(...)
 {
@@ -201,10 +197,9 @@ void doLatencyStats(double udiff)
   doAvg(g_stats.latencyAvg1000000, udiff, 1000000);
 }
 
-bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote)
+bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote, unsigned int& consumed)
 {
   uint16_t rqtype, rqclass;
-  unsigned int consumed;
   DNSName rqname;
   const struct dnsheader* dh = (struct dnsheader*) response;
 
@@ -253,6 +248,13 @@ void restoreFlags(struct dnsheader* dh, uint16_t origFlags)
   *flags |= origFlags;
 }
 
+bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags)
+{
+  restoreFlags(dq.dh, origFlags);
+
+  return addEDNSToQueryTurnedResponse(dq);
+}
+
 bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, std::vector<uint8_t>& rewrittenResponse, uint16_t addRoom)
 {
   struct dnsheader* dh = (struct dnsheader*) *response;
@@ -458,7 +460,8 @@ try {
         */
         ids->age = 0;
 
-        if (!responseContentMatches(response, responseLen, ids->qname, ids->qtype, ids->qclass, dss->remote)) {
+        unsigned int consumed = 0;
+        if (!responseContentMatches(response, responseLen, ids->qname, ids->qtype, ids->qclass, dss->remote, consumed)) {
           continue;
         }
 
@@ -472,13 +475,13 @@ try {
         }
 
         if(dh->tc && g_truncateTC) {
-          truncateTC(response, &responseLen);
+          truncateTC(response, &responseLen, consumed);
         }
 
         dh->id = ids->origID;
 
         uint16_t addRoom = 0;
-        DNSResponse dr(&ids->qname, ids->qtype, ids->qclass, &ids->origDest, &ids->origRemote, dh, sizeof(packet), responseLen, false, &ids->sentTime.d_start);
+        DNSResponse dr(&ids->qname, ids->qtype, ids->qclass, consumed, &ids->origDest, &ids->origRemote, dh, sizeof(packet), responseLen, false, &ids->sentTime.d_start);
 #ifdef HAVE_PROTOBUF
         dr.uniqueId = ids->uniqueId;
 #endif
@@ -892,39 +895,6 @@ NumberedServerVector getDownstreamCandidates(const pools_t& pools, const std::st
   return pool->getServers();
 }
 
-// goal in life - if you send us a reasonably normal packet, we'll get Z for you, otherwise 0
-int getEDNSZ(const char* packet, unsigned int len)
-try
-{
-  struct dnsheader* dh =(struct dnsheader*)packet;
-
-  if(ntohs(dh->qdcount) != 1 || dh->ancount!=0 || ntohs(dh->arcount)!=1 || dh->nscount!=0)
-    return 0;
-
-  if (len <= sizeof(dnsheader))
-    return 0;
-
-  unsigned int consumed;
-  DNSName qname(packet, len, sizeof(dnsheader), false, 0, 0, &consumed);
-  size_t pos = consumed + DNS_TYPE_SIZE + DNS_CLASS_SIZE;
-  uint16_t qtype, qclass;
-
-  if (len <= (sizeof(dnsheader)+pos))
-    return 0;
-
-  DNSName aname(packet, len, sizeof(dnsheader)+pos, true, &qtype, &qclass, &consumed);
-
-  if(qtype!=QType::OPT || sizeof(dnsheader)+pos+consumed+DNS_TYPE_SIZE+DNS_CLASS_SIZE+EDNS_EXTENDED_RCODE_SIZE+EDNS_VERSION_SIZE+1 >= len)
-    return 0;
-
-  uint8_t* z = (uint8_t*)packet+sizeof(dnsheader)+pos+consumed+DNS_TYPE_SIZE+DNS_CLASS_SIZE+EDNS_EXTENDED_RCODE_SIZE+EDNS_VERSION_SIZE;
-  return 0x100 * (*z) + *(z+1);
-}
-catch(...)
-{
-  return 0;
-}
-
 static void spoofResponseFromString(DNSQuestion& dq, const string& spoofContent)
 {
   string result;
@@ -1355,7 +1325,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     uint16_t qtype, qclass;
     unsigned int consumed = 0;
     DNSName qname(query, len, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
-    DNSQuestion dq(&qname, qtype, qclass, dest.sin4.sin_family != 0 ? &dest : &cs.local, &remote, dh, queryBufferSize, len, false, &queryRealTime);
+    DNSQuestion dq(&qname, qtype, qclass, consumed, dest.sin4.sin_family != 0 ? &dest : &cs.local, &remote, dh, queryBufferSize, len, false, &queryRealTime);
 
     if (!processQuery(holders, dq, poolname, &delayMsec, now))
     {
@@ -1363,13 +1333,13 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     }
 
     if(dq.dh->qr) { // something turned it into a response
-      restoreFlags(dh, origFlags);
+      fixUpQueryTurnedResponse(dq, origFlags);
 
       if (!cs.muted) {
         char* response = query;
         uint16_t responseLen = dq.len;
 
-        DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.local, dq.remote, reinterpret_cast<dnsheader*>(response), dq.size, responseLen, false, &queryRealTime);
+        DNSResponse dr(dq.qname, dq.qtype, dq.qclass, consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(response), dq.size, responseLen, false, &queryRealTime);
 #ifdef HAVE_PROTOBUF
         dr.uniqueId = dq.uniqueId;
 #endif
@@ -1433,7 +1403,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
       uint16_t cachedResponseSize = dq.size;
       uint32_t allowExpired = ss ? 0 : g_staleCacheEntriesTTL;
       if (packetCache->get(dq, consumed, dh->id, query, &cachedResponseSize, &cacheKey, subnet, allowExpired)) {
-        DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.local, dq.remote, reinterpret_cast<dnsheader*>(query), dq.size, cachedResponseSize, false, &queryRealTime);
+        DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(query), dq.size, cachedResponseSize, false, &queryRealTime);
 #ifdef HAVE_PROTOBUF
         dr.uniqueId = dq.uniqueId;
 #endif
@@ -1479,7 +1449,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
         dq.dh->rcode = RCode::ServFail;
         dq.dh->qr = true;
 
-        DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.local, dq.remote, reinterpret_cast<dnsheader*>(response), dq.size, responseLen, false, &queryRealTime);
+        DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(response), dq.size, responseLen, false, &queryRealTime);
 #ifdef HAVE_PROTOBUF
         dr.uniqueId = dq.uniqueId;
 #endif
index ca1c100ca1ee5c95776310d3688ef5f7a8bbddb6..c66e1f217b5a005958be8871e3f83e2679e69841 100644 (file)
@@ -59,8 +59,8 @@ typedef std::unordered_map<string, string> QTag;
 
 struct DNSQuestion
 {
-  DNSQuestion(const DNSName* name, uint16_t type, uint16_t class_, const ComboAddress* lc, const ComboAddress* rem, struct dnsheader* header, size_t bufferSize, uint16_t queryLen, bool isTcp, const struct timespec* queryTime_):
-    qname(name), qtype(type), qclass(class_), local(lc), remote(rem), dh(header), size(bufferSize), len(queryLen), ecsPrefixLength(rem->sin4.sin_family == AF_INET ? g_ECSSourcePrefixV4 : g_ECSSourcePrefixV6), tempFailureTTL(boost::none), tcp(isTcp), queryTime(queryTime_), ecsOverride(g_ECSOverride) { }
+  DNSQuestion(const DNSName* name, uint16_t type, uint16_t class_, unsigned int consumed_, const ComboAddress* lc, const ComboAddress* rem, struct dnsheader* header, size_t bufferSize, uint16_t queryLen, bool isTcp, const struct timespec* queryTime_):
+    qname(name), qtype(type), qclass(class_), local(lc), remote(rem), dh(header), size(bufferSize), consumed(consumed_), len(queryLen), ecsPrefixLength(rem->sin4.sin_family == AF_INET ? g_ECSSourcePrefixV4 : g_ECSSourcePrefixV6), tempFailureTTL(boost::none), tcp(isTcp), queryTime(queryTime_), ecsOverride(g_ECSOverride) { }
 
 #ifdef HAVE_PROTOBUF
   boost::optional<boost::uuids::uuid> uniqueId;
@@ -74,6 +74,7 @@ struct DNSQuestion
   std::shared_ptr<QTag> qTag{nullptr};
   struct dnsheader* dh;
   size_t size;
+  unsigned int consumed{0};
   uint16_t len;
   uint16_t ecsPrefixLength;
   boost::optional<uint32_t> tempFailureTTL;
@@ -88,8 +89,8 @@ struct DNSQuestion
 
 struct DNSResponse : DNSQuestion
 {
-  DNSResponse(const DNSName* name, uint16_t type, uint16_t class_, const ComboAddress* lc, const ComboAddress* rem, struct dnsheader* header, size_t bufferSize, uint16_t responseLen, bool isTcp, const struct timespec* queryTime_):
-    DNSQuestion(name, type, class_, lc, rem, header, bufferSize, responseLen, isTcp, queryTime_) { }
+  DNSResponse(const DNSName* name, uint16_t type, uint16_t class_, unsigned int consumed, const ComboAddress* lc, const ComboAddress* rem, struct dnsheader* header, size_t bufferSize, uint16_t responseLen, bool isTcp, const struct timespec* queryTime_):
+    DNSQuestion(name, type, class_, consumed, lc, rem, header, bufferSize, responseLen, isTcp, queryTime_) { }
 };
 
 /* so what could you do:
@@ -975,8 +976,7 @@ std::shared_ptr<DownstreamState> wrandom(const NumberedServerVector& servers, co
 std::shared_ptr<DownstreamState> whashed(const NumberedServerVector& servers, const DNSQuestion* dq);
 std::shared_ptr<DownstreamState> chashed(const NumberedServerVector& servers, const DNSQuestion* dq);
 std::shared_ptr<DownstreamState> roundrobin(const NumberedServerVector& servers, const DNSQuestion* dq);
-int getEDNSZ(const char* packet, unsigned int len);
-uint16_t getEDNSOptionCode(const char * packet, size_t len);
+
 void dnsdistWebserverThread(int sock, const ComboAddress& local, const string& password, const string& apiKey, const boost::optional<std::map<std::string, std::string> >&);
 bool getMsgLen32(int fd, uint32_t* len);
 bool putMsgLen32(int fd, uint32_t len);
@@ -987,9 +987,10 @@ void setLuaSideEffect();   // set to report a side effect, cancelling all _no_ s
 bool getLuaNoSideEffect(); // set if there were only explicit declarations of _no_ side effect
 void resetLuaSideEffect(); // reset to indeterminate state
 
-bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote);
+bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote, unsigned int& consumed);
 bool processQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, int* delayMsec, const struct timespec& now);
 bool processResponse(LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRulactions, DNSResponse& dr, int* delayMsec);
+bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags);
 bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, std::vector<uint8_t>& rewrittenResponse, uint16_t addRoom);
 void restoreFlags(struct dnsheader* dh, uint16_t origFlags);
 bool checkQueryHeaders(const struct dnsheader* dh);
@@ -1008,3 +1009,6 @@ bool addXPF(DNSQuestion& dq, uint16_t optionCode);
 extern bool g_snmpEnabled;
 extern bool g_snmpTrapsEnabled;
 extern DNSDistSNMPAgent* g_snmpAgent;
+extern bool g_addEDNSToSelfGeneratedResponses;
+
+static const size_t s_udpIncomingBufferSize{1500};
index 2a4c64ac29854058f41c7f8c3b1e23a35b7e0d44..2321902d92ca3591cdd66fb1ee6c1712a721d8cd 100644 (file)
@@ -377,7 +377,7 @@ public:
   }
   bool matches(const DNSQuestion* dq) const override
   {
-    return dq->dh->cd || (getEDNSZ((const char*)dq->dh, dq->len) & EDNS_HEADER_FLAG_DO);    // turns out dig sets ad by default..
+    return dq->dh->cd || (getEDNSZ(*dq) & EDNS_HEADER_FLAG_DO);    // turns out dig sets ad by default..
   }
 
   string toString() const override
index a0554dcad830e09a8b136c3712457b220d14a5de..ab9b3a3d9353c252ca9c108c70a291073bef2875 100644 (file)
@@ -947,3 +947,32 @@ TLSFrontend
 
   :param str certFile(s): The path to a X.509 certificate file in PEM format, or a list of paths to such files.
   :param str keyFile(s): The path to the private key file corresponding to the certificate, or a list of paths to such files, whose order should match the certFile(s) ones.
+
+EDNS on Self-generated answers
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+There are several mechanisms in dnsdist that turn an existing query into an answer right away,
+without reaching out to the backend, including :func:`SpoofAction`, :func:`RCodeAction`, :func:`TCAction`
+and returning a response from ``Lua``. Those responses should, according to :rfc:`6891`, contain an ``OPT``
+record if the received request had one, which is the case by default and can be disabled using
+:func:`setAddEDNSToSelfGeneratedResponses`.
+
+We must, however, provide a responder's maximum payload size in this record, and we can't easily know the
+maximum payload size of the actual backend so we need to provide one. The default value is 1500 and can be
+overriden using :func:`setPayloadSizeOnSelfGeneratedAnswers`.
+
+.. function:: setAddEDNSToSelfGeneratedResponses(add)
+
+  .. versionadded:: 1.3.3
+
+  Whether to add EDNS to self-generated responses, provided that the initial query had EDNS.
+
+  :param bool add: Whether to add EDNS, default is true.
+
+.. function:: setPayloadSizeOnSelfGeneratedAnswers(size)
+
+  .. versionadded:: 1.3.3
+
+  Set the UDP payload size advertised via EDNS on self-generated responses.
+
+  :param int size: The responder's maximum UDP payload size, in bytes. Default is 1500.
index 32fd6734358077994faeb63e35f4148e3c5e75b0..a402864068bb8e7b6e8b698b950cc5ad66b67db0 100644 (file)
@@ -32,7 +32,7 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) {
   /* the internal QPS limiter does not use the real time */
   gettime(&expiredTime);
 
-  DNSQuestion dq(&qname, qtype, qclass, &lc, &rem, dh, bufferSize, queryLen, isTcp, &queryRealTime);
+  DNSQuestion dq(&qname, qtype, qclass, qname.wirelength(), &lc, &rem, dh, bufferSize, queryLen, isTcp, &queryRealTime);
 
   for (size_t idx = 0; idx < maxQPS; idx++) {
     /* let's use different source ports, it shouldn't matter */
index a5fd5b4f0e7a025a54effca5626e22d5526bf93a..2fc13ac381d4c0791f64e75af2a6c136d09b687d 100644 (file)
@@ -933,3 +933,41 @@ uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t sectio
   }
   return result;
 }
+
+uint16_t getEDNSUDPPayloadSize(const char* packet, size_t length)
+{
+  if (length < sizeof(dnsheader)) {
+    return 0;
+  }
+
+  try
+  {
+    const dnsheader* dh = (const dnsheader*) packet;
+    DNSPacketMangler dpm(const_cast<char*>(packet), length);
+
+    const uint16_t qdcount = ntohs(dh->qdcount);
+    for(size_t n = 0; n < qdcount; ++n) {
+      dpm.skipLabel();
+      /* type and class */
+      dpm.skipBytes(4);
+    }
+    const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
+    for(size_t n = 0; n < numrecords; ++n) {
+      dpm.skipLabel();
+      const uint16_t dnstype = dpm.get16BitInt();
+      const uint16_t dnsclass = dpm.get16BitInt();
+
+      if(dnstype == QType::OPT) {
+        return dnsclass;
+      }
+
+      /* TTL */
+      dpm.skipBytes(4);
+      dpm.skipRData();
+    }
+  }
+  catch(...)
+  {
+  }
+  return 0;
+}
index 5892b5b5879ad5c9b2e7246e57617dddb8c8709a..b0182ebb5edab2498e82766e86241986e34331a9 100644 (file)
@@ -401,6 +401,7 @@ void editDNSPacketTTL(char* packet, size_t length, std::function<uint32_t(uint8_
 uint32_t getDNSPacketMinTTL(const char* packet, size_t length, bool* seenAuthSOA=nullptr);
 uint32_t getDNSPacketLength(const char* packet, size_t length);
 uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t section, uint16_t type);
+uint16_t getEDNSUDPPayloadSize(const char* packet, size_t length);;
 
 template<typename T>
 std::shared_ptr<T> getRR(const DNSRecord& dr)
index 75d0a5a62e1a16472bbbb7cd814314ef8950a23d..9bd45c2e5e03f7e2d087f8d7b5d4e3c903339700 100644 (file)
@@ -744,4 +744,218 @@ BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenLastOption) {
   validateResponse((const char *) newResponse.data(), newResponse.size(), true, 1);
 }
 
+static DNSQuestion getDNSQuestion(const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& lc, const ComboAddress& rem, const struct timespec& realTime, vector<uint8_t>& query, size_t len)
+{
+  dnsheader* dh = reinterpret_cast<dnsheader*>(query.data());
+
+  DNSQuestion dq(&qname, qtype, qclass, qname.wirelength(), &lc, &rem, dh, query.size(), len, false, &realTime);
+  return dq;
+}
+
+static DNSQuestion turnIntoResponse(const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& lc, const ComboAddress& rem, const struct timespec& queryRealTime, vector<uint8_t>&  query, bool resizeBuffer=true)
+{
+  size_t length = query.size();
+  if (resizeBuffer) {
+    query.resize(4096);
+  }
+
+  auto dq = getDNSQuestion(qname, qtype, qclass, lc, rem, queryRealTime, query, length);
+
+  BOOST_CHECK(addEDNSToQueryTurnedResponse(dq));
+
+  return dq;
+}
+
+static int getZ(const DNSName& qname, const uint16_t qtype, const uint16_t qclass, vector<uint8_t>& query)
+{
+  ComboAddress lc("127.0.0.1");
+  ComboAddress rem("127.0.0.1");
+  struct timespec queryRealTime;
+  gettime(&queryRealTime, true);
+  size_t length = query.size();
+  DNSQuestion dq = getDNSQuestion(qname, qtype, qclass, lc, rem, queryRealTime, query, length);
+
+  return getEDNSZ(dq);
+}
+
+BOOST_AUTO_TEST_CASE(test_getEDNSZ) {
+
+  DNSName qname("www.powerdns.com.");
+  uint16_t qtype = QType::A;
+  uint16_t qclass = QClass::IN;
+  EDNSSubnetOpts ecsOpts;
+  ecsOpts.source = Netmask(ComboAddress("127.0.0.1"), ECSSourcePrefixV4);
+  string origECSOptionStr = makeEDNSSubnetOptsString(ecsOpts);
+  EDNSCookiesOpt cookiesOpt;
+  cookiesOpt.client = string("deadbeef");
+  cookiesOpt.server = string("deadbeef");
+  string cookiesOptionStr = makeEDNSCookiesOptString(cookiesOpt);
+  DNSPacketWriter::optvect_t opts;
+  opts.push_back(make_pair(EDNSOptionCode::COOKIE, cookiesOptionStr));
+  opts.push_back(make_pair(EDNSOptionCode::ECS, origECSOptionStr));
+
+  {
+    /* no EDNS */
+    vector<uint8_t> query;
+    DNSPacketWriter pw(query, qname, qtype, qclass, 0);
+    pw.commit();
+
+    BOOST_CHECK_EQUAL(getZ(qname, qtype, qclass, query), 0);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(query.data()), query.size()), 0);
+  }
+
+  {
+    /* truncated EDNS */
+    vector<uint8_t> query;
+    DNSPacketWriter pw(query, qname, qtype, qclass, 0);
+    pw.addOpt(512, 0, EDNS_HEADER_FLAG_DO);
+    pw.commit();
+
+    query.resize(query.size() - (/* RDLEN */ sizeof(uint16_t) + /* last byte of TTL / Z */ 1));
+    BOOST_CHECK_EQUAL(getZ(qname, qtype, qclass, query), 0);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(query.data()), query.size()), 512);
+  }
+
+  {
+    /* valid EDNS, no options, DO not set */
+    vector<uint8_t> query;
+    DNSPacketWriter pw(query, qname, qtype, qclass, 0);
+    pw.addOpt(512, 0, 0);
+    pw.commit();
+
+    BOOST_CHECK_EQUAL(getZ(qname, qtype, qclass, query), 0);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(query.data()), query.size()), 512);
+  }
+
+  {
+    /* valid EDNS, no options, DO set */
+    vector<uint8_t> query;
+    DNSPacketWriter pw(query, qname, qtype, qclass, 0);
+    pw.addOpt(512, 0, EDNS_HEADER_FLAG_DO);
+    pw.commit();
+
+    BOOST_CHECK_EQUAL(getZ(qname, qtype, qclass, query), EDNS_HEADER_FLAG_DO);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(query.data()), query.size()), 512);
+  }
+
+    {
+    /* valid EDNS, options, DO not set */
+    vector<uint8_t> query;
+    DNSPacketWriter pw(query, qname, qtype, qclass, 0);
+    pw.addOpt(512, 0, 0, opts);
+    pw.commit();
+
+    BOOST_CHECK_EQUAL(getZ(qname, qtype, qclass, query), 0);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(query.data()), query.size()), 512);
+  }
+
+  {
+    /* valid EDNS, options, DO set */
+    vector<uint8_t> query;
+    DNSPacketWriter pw(query, qname, qtype, qclass, 0);
+    pw.addOpt(512, 0, EDNS_HEADER_FLAG_DO, opts);
+    pw.commit();
+
+    BOOST_CHECK_EQUAL(getZ(qname, qtype, qclass, query), EDNS_HEADER_FLAG_DO);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(query.data()), query.size()), 512);
+  }
+
+}
+
+BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) {
+
+  DNSName qname("www.powerdns.com.");
+  uint16_t qtype = QType::A;
+  uint16_t qclass = QClass::IN;
+  EDNSSubnetOpts ecsOpts;
+  ecsOpts.source = Netmask(ComboAddress("127.0.0.1"), ECSSourcePrefixV4);
+  string origECSOptionStr = makeEDNSSubnetOptsString(ecsOpts);
+  EDNSCookiesOpt cookiesOpt;
+  cookiesOpt.client = string("deadbeef");
+  cookiesOpt.server = string("deadbeef");
+  string cookiesOptionStr = makeEDNSCookiesOptString(cookiesOpt);
+  DNSPacketWriter::optvect_t opts;
+  opts.push_back(make_pair(EDNSOptionCode::COOKIE, cookiesOptionStr));
+  opts.push_back(make_pair(EDNSOptionCode::ECS, origECSOptionStr));
+  ComboAddress lc("127.0.0.1");
+  ComboAddress rem("127.0.0.1");
+  struct timespec queryRealTime;
+  gettime(&queryRealTime, true);
+
+  {
+    /* no EDNS */
+    vector<uint8_t> query;
+    DNSPacketWriter pw(query, qname, qtype, qclass, 0);
+    pw.getHeader()->qr = 1;
+    pw.getHeader()->rcode = RCode::NXDomain;
+    pw.commit();
+
+    auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query);
+    BOOST_CHECK_EQUAL(getEDNSZ(dq), 0);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(dq.dh), dq.len), 0);
+  }
+
+  {
+    /* truncated EDNS */
+    vector<uint8_t> query;
+    DNSPacketWriter pw(query, qname, qtype, qclass, 0);
+    pw.addOpt(512, 0, EDNS_HEADER_FLAG_DO);
+    pw.commit();
+
+    query.resize(query.size() - (/* RDLEN */ sizeof(uint16_t) + /* last byte of TTL / Z */ 1));
+    auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query);
+    BOOST_CHECK_EQUAL(getEDNSZ(dq), 0);
+    /* 512, because we don't touch a broken OPT record */
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(dq.dh), dq.len), 512);
+  }
+
+  {
+    /* valid EDNS, no options, DO not set */
+    vector<uint8_t> query;
+    DNSPacketWriter pw(query, qname, qtype, qclass, 0);
+    pw.addOpt(512, 0, 0);
+    pw.commit();
+
+    auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query);
+    BOOST_CHECK_EQUAL(getEDNSZ(dq), 0);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(dq.dh), dq.len), g_PayloadSizeSelfGenAnswers);
+  }
+
+  {
+    /* valid EDNS, no options, DO set */
+    vector<uint8_t> query;
+    DNSPacketWriter pw(query, qname, qtype, qclass, 0);
+    pw.addOpt(512, 0, EDNS_HEADER_FLAG_DO);
+    pw.commit();
+
+    auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query);
+    BOOST_CHECK_EQUAL(getEDNSZ(dq), EDNS_HEADER_FLAG_DO);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(dq.dh), dq.len), g_PayloadSizeSelfGenAnswers);
+  }
+
+  {
+    /* valid EDNS, options, DO not set */
+    vector<uint8_t> query;
+    DNSPacketWriter pw(query, qname, qtype, qclass, 0);
+    pw.addOpt(512, 0, 0, opts);
+    pw.commit();
+
+    auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query);
+    BOOST_CHECK_EQUAL(getEDNSZ(dq), 0);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(dq.dh), dq.len), g_PayloadSizeSelfGenAnswers);
+  }
+
+  {
+    /* valid EDNS, options, DO set */
+    vector<uint8_t> query;
+    DNSPacketWriter pw(query, qname, qtype, qclass, 0);
+    pw.addOpt(512, 0, EDNS_HEADER_FLAG_DO, opts);
+    pw.commit();
+
+    auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query);
+    BOOST_CHECK_EQUAL(getEDNSZ(dq), EDNS_HEADER_FLAG_DO);
+    BOOST_CHECK_EQUAL(getEDNSUDPPayloadSize(reinterpret_cast<const char*>(dq.dh), dq.len), g_PayloadSizeSelfGenAnswers);
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();
index 31c88c01e12e9a3a7e9268dfbe1e8f2a382f0fe1..eaf741e5087f33f8e7162eceee959fff254bef1c 100644 (file)
@@ -49,7 +49,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) {
       uint32_t key = 0;
       boost::optional<Netmask> subnet;
       auto dh = reinterpret_cast<dnsheader*>(query.data());
-      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, dh, query.size(), query.size(), false, &queryTime);
+      DNSQuestion dq(&a, QType::A, QClass::IN, 0, &remote, &remote, dh, query.size(), query.size(), false, &queryTime);
       bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet);
       BOOST_CHECK_EQUAL(found, false);
       BOOST_CHECK(!subnet);
@@ -82,7 +82,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) {
       uint16_t responseBufSize = sizeof(responseBuf);
       uint32_t key = 0;
       boost::optional<Netmask> subnet;
-      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, (struct dnsheader*) query.data(), query.size(), query.size(), false, &queryTime);
+      DNSQuestion dq(&a, QType::A, QClass::IN, 0, &remote, &remote, (struct dnsheader*) query.data(), query.size(), query.size(), false, &queryTime);
       bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet);
       if (found == true) {
         PC.expungeByName(a);
@@ -105,7 +105,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) {
       boost::optional<Netmask> subnet;
       char response[4096];
       uint16_t responseSize = sizeof(response);
-      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, (struct dnsheader*) query.data(), len, query.size(), false, &queryTime);
+      DNSQuestion dq(&a, QType::A, QClass::IN, 0, &remote, &remote, (struct dnsheader*) query.data(), len, query.size(), false, &queryTime);
       if(PC.get(dq, a.wirelength(), pwQ.getHeader()->id, response, &responseSize, &key, subnet)) {
         matches++;
       }
@@ -151,7 +151,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheServFailTTL) {
     uint32_t key = 0;
     boost::optional<Netmask> subnet;
     auto dh = reinterpret_cast<dnsheader*>(query.data());
-    DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, dh, query.size(), query.size(), false, &queryTime);
+    DNSQuestion dq(&a, QType::A, QClass::IN, 0, &remote, &remote, dh, query.size(), query.size(), false, &queryTime);
     bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet);
     BOOST_CHECK_EQUAL(found, false);
     BOOST_CHECK(!subnet);
@@ -208,7 +208,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNoDataTTL) {
     uint32_t key = 0;
     boost::optional<Netmask> subnet;
     auto dh = reinterpret_cast<dnsheader*>(query.data());
-    DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, dh, query.size(), query.size(), false, &queryTime);
+    DNSQuestion dq(&name, QType::A, QClass::IN, 0, &remote, &remote, dh, query.size(), query.size(), false, &queryTime);
     bool found = PC.get(dq, name.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet);
     BOOST_CHECK_EQUAL(found, false);
     BOOST_CHECK(!subnet);
@@ -264,7 +264,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNXDomainTTL) {
     uint32_t key = 0;
     boost::optional<Netmask> subnet;
     auto dh = reinterpret_cast<dnsheader*>(query.data());
-    DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, dh, query.size(), query.size(), false, &queryTime);
+    DNSQuestion dq(&name, QType::A, QClass::IN, 0, &remote, &remote, dh, query.size(), query.size(), false, &queryTime);
     bool found = PC.get(dq, name.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet);
     BOOST_CHECK_EQUAL(found, false);
     BOOST_CHECK(!subnet);
@@ -317,7 +317,7 @@ static void *threadMangler(void* off)
       uint32_t key = 0;
       boost::optional<Netmask> subnet;
       auto dh = reinterpret_cast<dnsheader*>(query.data());
-      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, dh, query.size(), query.size(), false, &queryTime);
+      DNSQuestion dq(&a, QType::A, QClass::IN, 0, &remote, &remote, dh, query.size(), query.size(), false, &queryTime);
       PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet);
 
       PC.insert(key, subnet, *(getFlagsFromDNSHeader(dh)), a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, 0, boost::none);
@@ -351,7 +351,7 @@ static void *threadReader(void* off)
       uint16_t responseBufSize = sizeof(responseBuf);
       uint32_t key = 0;
       boost::optional<Netmask> subnet;
-      DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, (struct dnsheader*) query.data(), query.size(), query.size(), false, &queryTime);
+      DNSQuestion dq(&a, QType::A, QClass::IN, 0, &remote, &remote, (struct dnsheader*) query.data(), query.size(), query.size(), false, &queryTime);
       bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet);
       if (!found) {
        g_missing++;
@@ -422,7 +422,7 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) {
     ComboAddress remote("192.0.2.1");
     struct timespec queryTime;
     gettime(&queryTime);
-    DNSQuestion dq(&qname, QType::AAAA, QClass::IN, &remote, &remote, pwQ.getHeader(), query.size(), query.size(), false, &queryTime);
+    DNSQuestion dq(&qname, QType::AAAA, QClass::IN, 0, &remote, &remote, pwQ.getHeader(), query.size(), query.size(), false, &queryTime);
     bool found = PC.get(dq, qname.wirelength(), 0, responseBuf, &responseBufSize, &key, subnetOut);
     BOOST_CHECK_EQUAL(found, false);
     BOOST_REQUIRE(subnetOut);
@@ -467,7 +467,7 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) {
     ComboAddress remote("192.0.2.1");
     struct timespec queryTime;
     gettime(&queryTime);
-    DNSQuestion dq(&qname, QType::AAAA, QClass::IN, &remote, &remote, pwQ.getHeader(), query.size(), query.size(), false, &queryTime);
+    DNSQuestion dq(&qname, QType::AAAA, QClass::IN, 0, &remote, &remote, pwQ.getHeader(), query.size(), query.size(), false, &queryTime);
     bool found = PC.get(dq, qname.wirelength(), 0, responseBuf, &responseBufSize, &secondKey, subnetOut);
     BOOST_CHECK_EQUAL(found, false);
     BOOST_CHECK_EQUAL(secondKey, key);
index 51cf6ac02fd55ed79285b9a44faf1811ab01c326..f202b87bf0e66b774550ec723fcb163c7d2ada64 100644 (file)
@@ -479,6 +479,10 @@ class DNSDistTest(unittest.TestCase):
         self.assertEquals(received.edns, -1)
         self.assertEquals(len(received.options), 0)
 
+    def checkMessageEDNSWithoutOptions(self, expected, received):
+        self.assertEquals(expected, received)
+        self.assertEquals(received.edns, 0)
+
     def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
         self.assertEquals(expected, received)
         self.assertEquals(received.edns, 0)
diff --git a/regression-tests.dnsdist/test_EDNSSelfGenerated.py b/regression-tests.dnsdist/test_EDNSSelfGenerated.py
new file mode 100644 (file)
index 0000000..7704641
--- /dev/null
@@ -0,0 +1,373 @@
+#!/usr/bin/env python
+import dns
+import clientsubnetoption
+from dnsdisttests import DNSDistTest
+from datetime import datetime, timedelta
+
+class TestEDNSSelfGenerated(DNSDistTest):
+    """
+    Check that dnsdist sends correct EDNS data on
+    self-generated (RCodeAction(), TCAction(), Lua..)
+    """
+
+    _config_template = """
+    addAction("rcode.edns-self.tests.powerdns.com.", RCodeAction(dnsdist.REFUSED))
+    addAction("tc.edns-self.tests.powerdns.com.", TCAction())
+
+    function luarule(dq)
+      return DNSAction.Nxdomain, ""
+    end
+
+    addLuaAction("lua.edns-self.tests.powerdns.com.", luarule)
+
+    addAction("spoof.edns-self.tests.powerdns.com.", SpoofAction({'192.0.2.1', '192.0.2.2'}))
+
+    setPayloadSizeOnSelfGeneratedAnswers(1042)
+
+    newServer{address="127.0.0.1:%s"}
+    """
+
+    def testNoEDNS(self):
+        """
+        EDNS on Self-Generated: No existing EDNS
+        """
+        name = 'no-edns.rcode.edns-self.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.REFUSED)
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+
+        name = 'no-edns.tc.edns-self.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.flags |= dns.flags.TC
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+
+        name = 'no-edns.lua.edns-self.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.NXDOMAIN)
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+
+        name = 'no-edns.spoof.edns-self.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.answer.append(dns.rrset.from_text(name,
+                                                           60,
+                                                           dns.rdataclass.IN,
+                                                           dns.rdatatype.A,
+                                                           '192.0.2.1', '192.0.2.2'))
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+
+    def testWithEDNSNoDO(self):
+        """
+        EDNS on Self-Generated: EDNS with DO=0
+        """
+        name = 'edns-no-do.rcode.edns-self.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=False)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.REFUSED)
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        name = 'edns-no-do.tc.edns-self.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=False)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.flags |= dns.flags.TC
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        name = 'edns-no-do.lua.edns-self.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=False)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.NXDOMAIN)
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        name = 'edns-no-do.spoof.edns-self.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=False)
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.answer.append(dns.rrset.from_text(name,
+                                                           60,
+                                                           dns.rdataclass.IN,
+                                                           dns.rdatatype.A,
+                                                           '192.0.2.1', '192.0.2.2'))
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+    def testWithEDNSWithDO(self):
+        """
+        EDNS on Self-Generated: EDNS with DO=1
+        """
+        name = 'edns-do.rcode.edns-self.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=True)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.REFUSED)
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        name = 'edns-do.tc.edns-self.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=True)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.flags |= dns.flags.TC
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        name = 'edns-do.lua.edns-self.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=True)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.NXDOMAIN)
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        name = 'edns-do.spoof.edns-self.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=True)
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.answer.append(dns.rrset.from_text(name,
+                                                           60,
+                                                           dns.rdataclass.IN,
+                                                           dns.rdatatype.A,
+                                                           '192.0.2.1', '192.0.2.2'))
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+    def testWithEDNSNoOptions(self):
+        """
+        EDNS on Self-Generated: EDNS with options in the query
+        """
+        name = 'edns-options.rcode.edns-self.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512, want_dnssec=True)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.REFUSED)
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        name = 'edns-options.tc.edns-self.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512, want_dnssec=True)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.flags |= dns.flags.TC
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        name = 'edns-options.lua.edns-self.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512, want_dnssec=True)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.NXDOMAIN)
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        name = 'edns-options.spoof.edns-self.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512, want_dnssec=True)
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.answer.append(dns.rrset.from_text(name,
+                                                           60,
+                                                           dns.rdataclass.IN,
+                                                           dns.rdatatype.A,
+                                                           '192.0.2.1', '192.0.2.2'))
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+        self.assertEquals(receivedResponse.payload, 1042)
+
+
+class TestEDNSSelfGeneratedDisabled(DNSDistTest):
+    """
+    Check that dnsdist does not send EDNS data on
+    self-generated (RCodeAction(), TCAction(), Lua..) when disabled
+    """
+
+    _config_template = """
+    setAddEDNSToSelfGeneratedResponses(false)
+
+    addAction("rcode.edns-self-disabled.tests.powerdns.com.", RCodeAction(dnsdist.REFUSED))
+    addAction("tc.edns-self-disabled.tests.powerdns.com.", TCAction())
+
+    function luarule(dq)
+      return DNSAction.Nxdomain, ""
+    end
+
+    addLuaAction("lua.edns-self-disabled.tests.powerdns.com.", luarule)
+
+    addAction("spoof.edns-self-disabled.tests.powerdns.com.", SpoofAction({'192.0.2.1', '192.0.2.2'}))
+
+    setPayloadSizeOnSelfGeneratedAnswers(1042)
+
+    newServer{address="127.0.0.1:%s"}
+    """
+
+    def testWithEDNSNoDO(self):
+        """
+        EDNS on Self-Generated (disabled): EDNS with DO=0
+        """
+        name = 'edns-no-do.rcode.edns-self-disabled.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=False)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.REFUSED)
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+
+        name = 'edns-no-do.tc.edns-self-disabled.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=False)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.flags |= dns.flags.TC
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+
+        name = 'edns-no-do.lua.edns-self-disabled.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=False)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.NXDOMAIN)
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+
+        name = 'edns-no-do.spoof.edns-self-disabled.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=False)
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.answer.append(dns.rrset.from_text(name,
+                                                           60,
+                                                           dns.rdataclass.IN,
+                                                           dns.rdatatype.A,
+                                                           '192.0.2.1', '192.0.2.2'))
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.checkMessageNoEDNS(expectedResponse, receivedResponse)