]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Allow altering the ECS behavior via rules and Lua 4519/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 29 Sep 2016 08:48:04 +0000 (10:48 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 29 Sep 2016 08:48:04 +0000 (10:48 +0200)
12 files changed:
pdns/README-dnsdist.md
pdns/dnsdist-ecs.cc
pdns/dnsdist-ecs.hh
pdns/dnsdist-lua.cc
pdns/dnsdist-lua2.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/dnsrulactions.cc
pdns/dnsrulactions.hh
pdns/test-dnsdist_cc.cc
regression-tests.dnsdist/test_EdnsClientSubnet.py

index ef2246d72f0728f303839b0a0dd0853951520407..fb39794e69b5a8453dae73292a39998118106c80 100644 (file)
@@ -167,6 +167,17 @@ be 192.0.2.0. This can be changed with:
 > setECSSourcePrefixV6(56)
 ```
 
+In addition to the global settings, rules and Lua bindings can alter this behavior per query:
+
+* calling `DisableECSAction()` or setting `dq.useECS` to false prevent the sending of the ECS option
+* calling `ECSOverrideAction(bool)` or setting `dq.ecsOverride` will override the global `setECSOverride()` value
+* calling `ECSPrefixLengthAction(v4, v6)` or setting `dq.ecsPrefixLength` will override the global
+`setECSSourcePrefixV4()` and `setECSSourcePrefixV6()` values
+
+In effect this means that for the EDNS Client Subnet option to be added to the request, `useClientSubnet`
+should be set to true for the backend used (default to false) and ECS should not have been disabled by calling
+`DisableECSAction()` or setting `dq.useECS` to false (default to true).
+
 TCP timeouts
 ------------
 
@@ -362,6 +373,7 @@ Current actions are:
  * Add the source MAC address to the query (MacAddrAction)
  * Skip the cache, if any
  * Log query content to a remote server (RemoteLogAction)
+ * Alter the EDNS Client Subnet parameters (DisableECSAction, ECSOverrideAction, ECSPrefixLengthAction)
 
 Current response actions are:
 
@@ -415,7 +427,10 @@ A DNS rule can be:
 Some specific actions do not stop the processing when they match, contrary to all other actions:
 
  * Delay
+ * DisableECS
  * Disable Validation
+ * ECSOverride
+ * ECSPrefixLength
  * Log
  * MacAddr
  * No Recurse
@@ -1312,9 +1327,12 @@ instantiate a server with additional parameters
     * `AllowResponseAction()`: let these packets go through
     * `DelayAction(milliseconds)`: delay the response by the specified amount of milliseconds (UDP-only)
     * `DelayResponseAction(milliseconds)`: delay the response by the specified amount of milliseconds (UDP-only)
+    * `DisableECSAction()`: disable the sending of ECS to the backend
     * `DisableValidationAction()`: set the CD bit in the question, let it go through
     * `DropAction()`: drop these packets
     * `DropResponseAction()`: drop these packets
+    * `ECSOverrideAction(bool)`: whether an existing ECS value should be overriden (true) or not (false)
+    * `ECSPrefixLengthAction(v4, v6)`: set the ECS prefix length
     * `LogAction([filename], [binary], [append], [buffered])`: Log a line for each query, to the specified file if any, to the console (require verbose) otherwise. When logging to a file, the `binary` optional parameter specifies whether we log in binary form (default) or in textual form, the `append` optional parameter specifies whether we open the file for appending or truncate each time (default), and the `buffered` optional parameter specifies whether writes to the file are buffered (default) or not.
     * `NoRecurseAction()`: strip RD bit from the question, let it go through
     * `PoolAction(poolname)`: set the packet into the specified pool
@@ -1420,6 +1438,8 @@ instantiate a server with additional parameters
         * member `wirelength()`: return the length on the wire
     * DNSQuestion related:
         * member `dh`: DNSHeader
+        * member `ecsOverride`: whether an existing ECS value should be overriden (settable)
+        * member `ecsPrefixLength`: the ECS prefix length to use (settable)
         * member `len`: the question length
         * member `localaddr`: ComboAddress of the local bind this question was received on
         * member `opcode`: the question opcode
@@ -1431,6 +1451,7 @@ instantiate a server with additional parameters
         * member `size`: the total size of the buffer starting at `dh`
         * member `skipCache`: whether to skip cache lookup / storing the answer for this question (settable)
         * member `tcp`: whether this question was received over a TCP socket
+        * member `useECS`: whether to send ECS to the backend (settable)
     * DNSHeader related
         * member `getRD()`: get recursion desired flag
         * member `setRD(bool)`: set recursion desired flag
index d78821ed45cdf1d253aa0bf67230c119732aee24..bab9e3fea4395ba46e25cf8ebb84e74fba01fa30 100644 (file)
@@ -29,7 +29,7 @@
 
 /* when we add EDNS to a query, we don't want to advertise
    a large buffer size */
-size_t q_EdnsUDPPayloadSize = 512;
+size_t g_EdnsUDPPayloadSize = 512;
 /* draft-ietf-dnsop-edns-client-subnet-04 "11.1.  Privacy" */
 uint16_t g_ECSSourcePrefixV4 = 24;
 uint16_t g_ECSSourcePrefixV6 = 56;
@@ -231,9 +231,9 @@ static int getEDNSOptionsStart(char* packet, const size_t offset, const size_t l
   return 0;
 }
 
-static void generateECSOption(const ComboAddress& source, string& res)
+static void generateECSOption(const ComboAddress& source, string& res, uint16_t ECSPrefixLength)
 {
-  Netmask sourceNetmask(source, source.sin4.sin_family == AF_INET ? g_ECSSourcePrefixV4 : g_ECSSourcePrefixV6);
+  Netmask sourceNetmask(source, ECSPrefixLength);
   EDNSSubnetOpts ecsOpts;
   ecsOpts.source = sourceNetmask;
   string payload = makeEDNSSubnetOptsString(ecsOpts);
@@ -250,7 +250,7 @@ void generateOptRR(const std::string& optRData, string& res)
   edns0.Z = 0;
   
   dh.d_type = htons(QType::OPT);
-  dh.d_class = htons(q_EdnsUDPPayloadSize);
+  dh.d_class = htons(g_EdnsUDPPayloadSize);
   static_assert(sizeof(EDNS0Record) == sizeof(dh.d_ttl), "sizeof(EDNS0Record) must match sizeof(dnsrecordheader.d_ttl)");
   memcpy(&dh.d_ttl, &edns0, sizeof edns0);
   dh.d_clen = htons((uint16_t) optRData.length());
@@ -259,14 +259,14 @@ void generateOptRR(const std::string& optRData, string& res)
   res.append(optRData.c_str(), optRData.length());
 }
 
-static void replaceEDNSClientSubnetOption(char * const packet, const size_t packetSize, uint16_t * const len, string& largerPacket, const ComboAddress& remote, char * const oldEcsOptionStart, size_t const oldEcsOptionSize, unsigned char * const optRDLen)
+static void replaceEDNSClientSubnetOption(char * const packet, const size_t packetSize, uint16_t * const len, string& largerPacket, const ComboAddress& remote, char * const oldEcsOptionStart, size_t const oldEcsOptionSize, unsigned char * const optRDLen, uint16_t ECSPrefixLength)
 {
   assert(packet != NULL);
   assert(len != NULL);
   assert(oldEcsOptionStart != NULL);
   assert(optRDLen != NULL);
   string ECSOption;
-  generateECSOption(remote, ECSOption);
+  generateECSOption(remote, ECSOption, ECSPrefixLength);
 
   if (ECSOption.size() == oldEcsOptionSize) {
     /* same size as the existing option */
@@ -309,7 +309,7 @@ static void replaceEDNSClientSubnetOption(char * const packet, const size_t pack
   }
 }
 
-void handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, string& largerPacket, bool* const ednsAdded, bool* const ecsAdded, const ComboAddress& remote)
+void handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, string& largerPacket, bool* const ednsAdded, bool* const ecsAdded, const ComboAddress& remote, bool overrideExisting, uint16_t ecsPrefixLength)
 {
   assert(packet != NULL);
   assert(len != NULL);
@@ -318,7 +318,7 @@ void handleEDNSClientSubnet(char* const packet, const size_t packetSize, const u
   assert(ecsAdded != NULL);
   unsigned char * optRDLen = NULL;
   size_t remaining = 0;
-        
+
   int res = getEDNSOptionsStart(packet, consumed, *len, (char**) &optRDLen, &remaining);
         
   if (res == 0) {
@@ -329,15 +329,15 @@ void handleEDNSClientSubnet(char* const packet, const size_t packetSize, const u
     
     if (res == 0) {
       /* there is already an ECS value */
-      if (g_ECSOverride) {
-        replaceEDNSClientSubnetOption(packet, packetSize, len, largerPacket, remote, ecsOptionStart, ecsOptionSize, optRDLen);
+      if (overrideExisting) {
+        replaceEDNSClientSubnetOption(packet, packetSize, len, largerPacket, remote, ecsOptionStart, ecsOptionSize, optRDLen, ecsPrefixLength);
       }
     } else {
       /* we need to add one EDNS0 ECS option, fixing the size of EDNS0 RDLENGTH */
       /* getEDNSOptionsStart has already checked that there is exactly one AR,
          no NS and no AN */
       string ECSOption;
-      generateECSOption(remote, ECSOption);
+      generateECSOption(remote, ECSOption, ecsPrefixLength);
       const size_t ECSOptionSize = ECSOption.size();
       
       uint16_t newRDLen = (optRDLen[0] * 256) + optRDLen[1];
@@ -366,7 +366,7 @@ void handleEDNSClientSubnet(char* const packet, const size_t packetSize, const u
     string EDNSRR;
     struct dnsheader* dh = (struct dnsheader*) packet;
     string optRData;
-    generateECSOption(remote, optRData);
+    generateECSOption(remote, optRData, ecsPrefixLength);
     generateOptRR(optRData, EDNSRR);
     uint16_t arcount = ntohs(dh->arcount);
     arcount++;
index 5f197c000d8793b5efb861f553d66d601a0981c1..84013eb05b63355225928d6ce0bf0e41f544306b 100644 (file)
@@ -23,7 +23,7 @@
 
 int rewriteResponseWithoutEDNS(const char * packet, size_t len, vector<uint8_t>& newContent);
 int locateEDNSOptRR(char * packet, size_t len, char ** optStart, size_t * optLen, bool * last);
-void handleEDNSClientSubnet(char * packet, size_t packetSize, unsigned int consumed, uint16_t * len, string& largerPacket, bool* ednsAdded, bool* ecsAdded, const ComboAddress& remote);
+void handleEDNSClientSubnet(char * packet, size_t packetSize, unsigned int consumed, uint16_t * len, string& largerPacket, bool* ednsAdded, bool* ecsAdded, const ComboAddress& remote, bool overrideExisting, uint16_t ecsPrefixLength);
 void generateOptRR(const std::string& optRData, string& res);
 int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optionCodeToRemove);
 int rewriteResponseWithoutEDNSOption(const char * packet, const size_t len, const uint16_t optionCodeToSkip, vector<uint8_t>& newContent);
index fcb8a590ade8a7bd626cccd25183a7f77522edec..c5e90599484ed87caddf3bd9f0d3c7d30681c9d7 100644 (file)
@@ -1481,6 +1481,9 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
   g_lua.registerMember<size_t (DNSQuestion::*)>("size", [](const DNSQuestion& dq) -> size_t { return dq.size; }, [](DNSQuestion& dq, size_t newSize) { (void) newSize; });
   g_lua.registerMember<bool (DNSQuestion::*)>("tcp", [](const DNSQuestion& dq) -> bool { return dq.tcp; }, [](DNSQuestion& dq, bool newTcp) { (void) newTcp; });
   g_lua.registerMember<bool (DNSQuestion::*)>("skipCache", [](const DNSQuestion& dq) -> bool { return dq.skipCache; }, [](DNSQuestion& dq, bool newSkipCache) { dq.skipCache = newSkipCache; });
+  g_lua.registerMember<bool (DNSQuestion::*)>("useECS", [](const DNSQuestion& dq) -> bool { return dq.useECS; }, [](DNSQuestion& dq, bool useECS) { dq.useECS = useECS; });
+  g_lua.registerMember<bool (DNSQuestion::*)>("ecsOverride", [](const DNSQuestion& dq) -> bool { return dq.ecsOverride; }, [](DNSQuestion& dq, bool ecsOverride) { dq.ecsOverride = ecsOverride; });
+  g_lua.registerMember<uint16_t (DNSQuestion::*)>("ecsPrefixLength", [](const DNSQuestion& dq) -> uint16_t { return dq.ecsPrefixLength; }, [](DNSQuestion& dq, uint16_t newPrefixLength) { dq.ecsPrefixLength = newPrefixLength; });
 
   g_lua.writeFunction("setMaxTCPClientThreads", [](uint64_t max) {
       if (!g_configurationDone) {
index 7bd9f8826c48ab076b8c972d4515920045269fad..7afe7c0f87c337c040a466f0f4f6c4f9a931f441 100644 (file)
@@ -743,6 +743,18 @@ void moreLua(bool client)
         return std::shared_ptr<DNSAction>(new TeeAction(ComboAddress(remote, 53), addECS ? *addECS : false));
       });
 
+    g_lua.writeFunction("ECSPrefixLengthAction", [](uint16_t v4PrefixLength, uint16_t v6PrefixLength) {
+        return std::shared_ptr<DNSAction>(new ECSPrefixLengthAction(v4PrefixLength, v6PrefixLength));
+      });
+
+    g_lua.writeFunction("ECSOverrideAction", [](bool ecsOverride) {
+        return std::shared_ptr<DNSAction>(new ECSOverrideAction(ecsOverride));
+      });
+
+    g_lua.writeFunction("DisableECSAction", []() {
+        return std::shared_ptr<DNSAction>(new DisableECSAction());
+      });
+
     g_lua.registerFunction<void(DNSAction::*)()>("printStats", [](const DNSAction& ta) {
         setLuaNoSideEffect();
         auto stats = ta.getStats();
index e57f8064f22a58fc5f48d11e6b7267d666862628..21f43570d5e5d086aab7ce8010a7a6ce7064131d 100644 (file)
@@ -292,9 +292,9 @@ void* tcpClientThread(int pipefd)
          packetCache = serverPool->packetCache;
        }
 
-        if (ds && ds->useECS) {
+        if (dq.useECS && ds && ds->useECS) {
           uint16_t newLen = dq.len;
-          handleEDNSClientSubnet(queryBuffer, dq.size, consumed, &newLen, largerQuery, &ednsAdded, &ecsAdded, ci.remote);
+          handleEDNSClientSubnet(queryBuffer, dq.size, consumed, &newLen, largerQuery, &ednsAdded, &ecsAdded, ci.remote, dq.ecsOverride, dq.ecsPrefixLength);
           if (largerQuery.empty() == false) {
             query = largerQuery.c_str();
             dq.len = (uint16_t) largerQuery.size();
index f4fcd72f3360a7ef220d27c7419a827ea2fc9bd1..847cdcacf3ef0b4fea1ba5e9cae5c2ab6759702e 100644 (file)
@@ -1040,8 +1040,8 @@ try
 
       bool ednsAdded = false;
       bool ecsAdded = false;
-      if (ss && ss->useECS) {
-        handleEDNSClientSubnet(query, dq.size, consumed, &dq.len, largerQuery, &(ednsAdded), &(ecsAdded), remote);
+      if (dq.useECS && ss && ss->useECS) {
+        handleEDNSClientSubnet(query, dq.size, consumed, &dq.len, largerQuery, &(ednsAdded), &(ecsAdded), remote, dq.ecsOverride, dq.ecsPrefixLength);
       }
 
       uint32_t cacheKey = 0;
index b494e37a02f0d6cc1b879d45b0999957282f7e7d..e6132997e96a87dfe18597a14f07776328ec0c3e 100644 (file)
@@ -444,9 +444,13 @@ struct DownstreamState
 };
 using servers_t =vector<std::shared_ptr<DownstreamState>>;
 
+extern uint16_t g_ECSSourcePrefixV4;
+extern uint16_t g_ECSSourcePrefixV6;
+extern bool g_ECSOverride;
+
 struct DNSQuestion
 {
-  DNSQuestion(const DNSName* name, uint16_t type, uint16_t class_, const ComboAddress* lc, const ComboAddress* rem, struct dnsheader* header, size_t bufferSize, uint16_t queryLen, bool isTcp): qname(name), qtype(type), qclass(class_), local(lc), remote(rem), dh(header), size(bufferSize), len(queryLen), tcp(isTcp) { }
+  DNSQuestion(const DNSName* name, uint16_t type, uint16_t class_, const ComboAddress* lc, const ComboAddress* rem, struct dnsheader* header, size_t bufferSize, uint16_t queryLen, bool isTcp): qname(name), qtype(type), qclass(class_), local(lc), remote(rem), dh(header), size(bufferSize), len(queryLen), ecsPrefixLength(rem->sin4.sin_family == AF_INET ? g_ECSSourcePrefixV4 : g_ECSSourcePrefixV6), tcp(isTcp), ecsOverride(g_ECSOverride) { }
 
 #ifdef HAVE_PROTOBUF
   boost::uuids::uuid uniqueId;
@@ -459,8 +463,11 @@ struct DNSQuestion
   struct dnsheader* dh;
   size_t size;
   uint16_t len;
+  uint16_t ecsPrefixLength;
   const bool tcp;
   bool skipCache{false};
+  bool ecsOverride;
+  bool useECS{true};    
 };
 
 struct DNSResponse : DNSQuestion
@@ -665,9 +672,6 @@ extern std::atomic<bool> g_configurationDone;
 extern uint64_t g_maxTCPClientThreads;
 extern uint64_t g_maxTCPQueuedConnections;
 extern std::atomic<uint16_t> g_cacheCleaningDelay;
-extern uint16_t g_ECSSourcePrefixV4;
-extern uint16_t g_ECSSourcePrefixV6;
-extern bool g_ECSOverride;
 extern bool g_verboseHealthChecks;
 extern uint32_t g_staleCacheEntriesTTL;
 
index 8518a7197d0ceb01fb0d9131b55cdfbb6fed7e19..49292763e2122921f9fa50d2f1e115bd499af414 100644 (file)
@@ -56,7 +56,7 @@ DNSAction::Action TeeAction::operator()(DNSQuestion* dq, string* ruleresult) con
       query.reserve(dq->size);
       query.assign((char*) dq->dh, len);
 
-      handleEDNSClientSubnet((char*) query.c_str(), query.size(), dq->qname->wirelength(), &len, larger, &ednsAdded, &ecsAdded, *dq->remote);
+      handleEDNSClientSubnet((char*) query.c_str(), query.size(), dq->qname->wirelength(), &len, larger, &ednsAdded, &ecsAdded, *dq->remote, dq->ecsOverride, dq->ecsPrefixLength);
 
       if (larger.empty()) {
         res = send(d_fd, query.c_str(), len, 0);
index 69a7564c48eec7266704075d13c42331090b6567..0244cd6ea97773bcdc6315c5d5734bf9b6087fa0 100644 (file)
@@ -1001,6 +1001,60 @@ public:
   }
 };
 
+class ECSPrefixLengthAction : public DNSAction
+{
+public:
+  ECSPrefixLengthAction(uint16_t v4Length, uint16_t v6Length) : d_v4PrefixLength(v4Length), d_v6PrefixLength(v6Length)
+  {
+  }
+  DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
+  {
+    dq->ecsPrefixLength = dq->remote->sin4.sin_family == AF_INET ? d_v4PrefixLength : d_v6PrefixLength;
+    return Action::None;
+  }
+  string toString() const override
+  {
+    return "set ECS prefix length to " + std::to_string(d_v4PrefixLength) + "/" + std::to_string(d_v6PrefixLength);
+  }
+private:
+  uint16_t d_v4PrefixLength;
+  uint16_t d_v6PrefixLength;
+};
+
+class ECSOverrideAction : public DNSAction
+{
+public:
+  ECSOverrideAction(bool ecsOverride) : d_ecsOverride(ecsOverride)
+  {
+  }
+  DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
+  {
+    dq->ecsOverride = d_ecsOverride;
+    return Action::None;
+  }
+  string toString() const override
+  {
+    return "set ECS override to " + std::to_string(d_ecsOverride);
+  }
+private:
+  bool d_ecsOverride;
+};
+
+
+class DisableECSAction : public DNSAction
+{
+public:
+  DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
+  {
+    dq->useECS = false;
+    return Action::None;
+  }
+  string toString() const override
+  {
+    return "disable ECS";
+  }
+};
+
 class RemoteLogAction : public DNSAction, public boost::noncopyable
 {
 public:
index fe8a8ff36061c77cb37343739c06055c7e466b48..7b17c4408de79f1c332f306a8645effa3ca26ebe 100644 (file)
@@ -43,6 +43,9 @@ bool g_console{true};
 bool g_syslog{true};
 bool g_verbose{true};
 
+static const uint16_t ECSSourcePrefixV4 = 24;
+static const uint16_t ECSSourcePrefixV6 = 56;
+
 static void validateQuery(const char * packet, size_t packetSize)
 {
   MOADNSParser mdp(packet, packetSize);
@@ -91,7 +94,7 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNS)
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote);
+  handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
   BOOST_CHECK((size_t) len > query.size());
   BOOST_CHECK_EQUAL(largerPacket.size(), 0);
   BOOST_CHECK_EQUAL(ednsAdded, true);
@@ -105,7 +108,7 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNS)
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote);
+  handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
   BOOST_CHECK_EQUAL((size_t) len, query.size());
   BOOST_CHECK(largerPacket.size() > query.size());
   BOOST_CHECK_EQUAL(ednsAdded, true);
@@ -137,7 +140,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote);
+  handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
   BOOST_CHECK((size_t) len > query.size());
   BOOST_CHECK_EQUAL(largerPacket.size(), 0);
   BOOST_CHECK_EQUAL(ednsAdded, false);
@@ -151,7 +154,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote);
+  handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
   BOOST_CHECK_EQUAL((size_t) len, query.size());
   BOOST_CHECK(largerPacket.size() > query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
@@ -171,7 +174,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSameSize) {
   DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0);
   pw.getHeader()->rd = 1;
   EDNSSubnetOpts ecsOpts;
-  ecsOpts.source = Netmask(origRemote, g_ECSSourcePrefixV4);
+  ecsOpts.source = Netmask(origRemote, ECSSourcePrefixV4);
   string origECSOption = makeEDNSSubnetOptsString(ecsOpts);
   DNSPacketWriter::optvect_t opts;
   opts.push_back(make_pair(EDNSOptionCode::ECS, origECSOption));
@@ -189,8 +192,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSameSize) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  g_ECSOverride = true;
-  handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote);
+  handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
   BOOST_CHECK_EQUAL((size_t) len, query.size());
   BOOST_CHECK_EQUAL(largerPacket.size(), 0);
   BOOST_CHECK_EQUAL(ednsAdded, false);
@@ -228,8 +230,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSmaller) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  g_ECSOverride = true;
-  handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote);
+  handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
   BOOST_CHECK((size_t) len < query.size());
   BOOST_CHECK_EQUAL(largerPacket.size(), 0);
   BOOST_CHECK_EQUAL(ednsAdded, false);
@@ -267,8 +268,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithLarger) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  g_ECSOverride = true;
-  handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote);
+  handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
   BOOST_CHECK((size_t) len > query.size());
   BOOST_CHECK_EQUAL(largerPacket.size(), 0);
   BOOST_CHECK_EQUAL(ednsAdded, false);
@@ -282,8 +282,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithLarger) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  g_ECSOverride = true;
-  handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote);
+  handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
   BOOST_CHECK_EQUAL((size_t) len, query.size());
   BOOST_CHECK(largerPacket.size() > query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
@@ -398,7 +397,7 @@ BOOST_AUTO_TEST_CASE(removeECSWhenOnlyOption) {
   pw.commit();
 
   EDNSSubnetOpts ecsOpts;
-  ecsOpts.source = Netmask(origRemote, g_ECSSourcePrefixV4);
+  ecsOpts.source = Netmask(origRemote, ECSSourcePrefixV4);
   string origECSOptionStr = makeEDNSSubnetOptsString(ecsOpts);
   DNSPacketWriter::optvect_t opts;
   opts.push_back(make_pair(EDNSOptionCode::ECS, origECSOptionStr));
@@ -445,7 +444,7 @@ BOOST_AUTO_TEST_CASE(removeECSWhenFirstOption) {
   pw.commit();
 
   EDNSSubnetOpts ecsOpts;
-  ecsOpts.source = Netmask(origRemote, g_ECSSourcePrefixV4);
+  ecsOpts.source = Netmask(origRemote, ECSSourcePrefixV6);
   string origECSOptionStr = makeEDNSSubnetOptsString(ecsOpts);
   EDNSCookiesOpt cookiesOpt;
   cookiesOpt.client = string("deadbeef");
@@ -497,7 +496,7 @@ BOOST_AUTO_TEST_CASE(removeECSWhenIntermediaryOption) {
   pw.commit();
 
   EDNSSubnetOpts ecsOpts;
-  ecsOpts.source = Netmask(origRemote, g_ECSSourcePrefixV4);
+  ecsOpts.source = Netmask(origRemote, ECSSourcePrefixV4);
   string origECSOptionStr = makeEDNSSubnetOptsString(ecsOpts);
 
   EDNSCookiesOpt cookiesOpt;
@@ -557,7 +556,7 @@ BOOST_AUTO_TEST_CASE(removeECSWhenLastOption) {
   cookiesOpt.server = string("deadbeef");
   string cookiesOptionStr = makeEDNSCookiesOptString(cookiesOpt);
   EDNSSubnetOpts ecsOpts;
-  ecsOpts.source = Netmask(origRemote, g_ECSSourcePrefixV4);
+  ecsOpts.source = Netmask(origRemote, ECSSourcePrefixV4);
   string origECSOptionStr = makeEDNSSubnetOptsString(ecsOpts);
   DNSPacketWriter::optvect_t opts;
   opts.push_back(make_pair(EDNSOptionCode::COOKIE, cookiesOptionStr));
@@ -601,7 +600,7 @@ BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenOnlyOption) {
   pw.xfr32BitInt(0x01020304);
 
   EDNSSubnetOpts ecsOpts;
-  ecsOpts.source = Netmask(origRemote, g_ECSSourcePrefixV4);
+  ecsOpts.source = Netmask(origRemote, ECSSourcePrefixV4);
   string origECSOptionStr = makeEDNSSubnetOptsString(ecsOpts);
   DNSPacketWriter::optvect_t opts;
   opts.push_back(make_pair(EDNSOptionCode::ECS, origECSOptionStr));
@@ -638,7 +637,7 @@ BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenFirstOption) {
   pw.xfr32BitInt(0x01020304);
 
   EDNSSubnetOpts ecsOpts;
-  ecsOpts.source = Netmask(origRemote, g_ECSSourcePrefixV4);
+  ecsOpts.source = Netmask(origRemote, ECSSourcePrefixV4);
   string origECSOptionStr = makeEDNSSubnetOptsString(ecsOpts);
   EDNSCookiesOpt cookiesOpt;
   cookiesOpt.client = string("deadbeef");
@@ -680,7 +679,7 @@ BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenIntermediaryOption) {
   pw.xfr32BitInt(0x01020304);
 
   EDNSSubnetOpts ecsOpts;
-  ecsOpts.source = Netmask(origRemote, g_ECSSourcePrefixV4);
+  ecsOpts.source = Netmask(origRemote, ECSSourcePrefixV4);
   string origECSOptionStr = makeEDNSSubnetOptsString(ecsOpts);
   EDNSCookiesOpt cookiesOpt;
   cookiesOpt.client = string("deadbeef");
@@ -724,7 +723,7 @@ BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenLastOption) {
   pw.xfr32BitInt(0x01020304);
 
   EDNSSubnetOpts ecsOpts;
-  ecsOpts.source = Netmask(origRemote, g_ECSSourcePrefixV4);
+  ecsOpts.source = Netmask(origRemote, ECSSourcePrefixV4);
   string origECSOptionStr = makeEDNSSubnetOptsString(ecsOpts);
   EDNSCookiesOpt cookiesOpt;
   cookiesOpt.client = string("deadbeef");
index 273a7dee560f69481a9dc32efcd54bd6b0f5fe6a..701fe01929961837bb475c7d2a03d847bf761998 100644 (file)
@@ -4,7 +4,51 @@ import clientsubnetoption
 import cookiesoption
 from dnsdisttests import DNSDistTest
 
-class TestEdnsClientSubnetNoOverride(DNSDistTest):
+class TestEdnsClientSubnet(DNSDistTest):
+    def compareOptions(self, a, b):
+        self.assertEquals(len(a), len(b))
+        for idx in xrange(len(a)):
+            self.assertEquals(a[idx], b[idx])
+
+    def checkMessageNoEDNS(self, expected, received):
+        self.assertEquals(expected, received)
+        self.assertEquals(received.edns, -1)
+        self.assertEquals(len(received.options), 0)
+
+    def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
+        self.assertEquals(expected, received)
+        self.assertEquals(received.edns, 0)
+        self.assertEquals(len(received.options), withCookies)
+        if withCookies:
+            for option in received.options:
+                self.assertEquals(option.otype, 10)
+
+    def checkMessageEDNSWithECS(self, expected, received):
+        self.assertEquals(expected, received)
+        self.assertEquals(received.edns, 0)
+        self.assertEquals(len(received.options), 1)
+        self.assertEquals(received.options[0].otype, clientsubnetoption.ASSIGNED_OPTION_CODE)
+        self.compareOptions(expected.options, received.options)
+
+    def checkQueryEDNSWithECS(self, expected, received):
+        self.checkMessageEDNSWithECS(expected, received)
+
+    def checkResponseEDNSWithECS(self, expected, received):
+        self.checkMessageEDNSWithECS(expected, received)
+
+    def checkQueryEDNSWithoutECS(self, expected, received):
+        self.checkMessageEDNSWithoutECS(expected, received)
+
+    def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
+        self.checkMessageEDNSWithoutECS(expected, received, withCookies)
+
+    def checkQueryNoEDNS(self, expected, received):
+        self.checkMessageNoEDNS(expected, received)
+
+    def checkResponseNoEDNS(self, expected, received):
+        self.checkMessageNoEDNS(expected, received)
+
+class TestEdnsClientSubnetNoOverride(TestEdnsClientSubnet):
     """
     dnsdist is configured to add the EDNS0 Client Subnet
     option, but only if it's not already present in the
@@ -43,19 +87,15 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-        self.assertEquals(receivedResponse.edns, -1)
-        self.assertEquals(len(receivedResponse.options), 0)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
         (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-        self.assertEquals(receivedResponse.edns, -1)
-        self.assertEquals(len(receivedResponse.options), 0)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
     def testWithEDNSNoECS(self):
         """
@@ -84,19 +124,15 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-        self.assertEquals(receivedResponse.edns, 0)
-        self.assertEquals(len(receivedResponse.options), 0)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
 
         (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-        self.assertEquals(receivedResponse.edns, 0)
-        self.assertEquals(len(receivedResponse.options), 0)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
 
     def testWithEDNSECS(self):
         """
@@ -119,23 +155,20 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
+
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-        self.assertEquals(receivedResponse.edns, 0)
-        self.assertEquals(len(receivedResponse.options), 0)
+        self.checkQueryEDNSWithECS(query, receivedQuery)
+        self.checkResponseEDNSWithoutECS(response, receivedResponse)
 
         (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-        self.assertEquals(receivedResponse.edns, 0)
-        self.assertEquals(len(receivedResponse.options), 0)
+        self.checkQueryEDNSWithECS(query, receivedQuery)
+        self.checkResponseEDNSWithoutECS(response, receivedResponse)
 
     def testWithoutEDNSResponseWithECS(self):
         """
@@ -168,19 +201,15 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-        self.assertEquals(receivedResponse.edns, -1)
-        self.assertEquals(len(receivedResponse.options), 0)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
         (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-        self.assertEquals(receivedResponse.edns, -1)
-        self.assertEquals(len(receivedResponse.options), 0)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
     def testWithEDNSNoECSResponseWithECS(self):
         """
@@ -213,19 +242,15 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-        self.assertEquals(receivedResponse.edns, 0)
-        self.assertEquals(len(receivedResponse.options), 0)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
 
         (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-        self.assertEquals(receivedResponse.edns, 0)
-        self.assertEquals(len(receivedResponse.options), 0)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
 
     def testWithEDNSNoECSResponseWithCookiesThenECS(self):
         """
@@ -254,24 +279,21 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
         expectedResponse.answer.append(rrset)
+        expectedResponse.use_edns(edns=True, payload=4096, options=[ecoResponse])
 
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-        self.assertEquals(receivedResponse.edns, 0)
-        self.assertEquals(len(receivedResponse.options), 1)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=1)
 
         (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-        self.assertEquals(receivedResponse.edns, 0)
-        self.assertEquals(len(receivedResponse.options), 1)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=1)
 
     def testWithEDNSNoECSResponseWithECSThenCookies(self):
         """
@@ -300,24 +322,21 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
         expectedResponse.answer.append(rrset)
+        response.use_edns(edns=True, payload=4096, options=[ecoResponse])
 
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-        self.assertEquals(receivedResponse.edns, 0)
-        self.assertEquals(len(receivedResponse.options), 1)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=1)
 
         (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-        self.assertEquals(receivedResponse.edns, 0)
-        self.assertEquals(len(receivedResponse.options), 1)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=1)
 
     def testWithEDNSNoECSResponseWithCookiesThenECSThenCookies(self):
         """
@@ -351,22 +370,18 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-        self.assertEquals(receivedResponse.edns, 0)
-        self.assertEquals(len(receivedResponse.options), 2)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=2)
 
         (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-        self.assertEquals(receivedResponse.edns, 0)
-        self.assertEquals(len(receivedResponse.options), 2)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=2)
 
 
-class TestEdnsClientSubnetOverride(DNSDistTest):
+class TestEdnsClientSubnetOverride(TestEdnsClientSubnet):
     """
     dnsdist is configured to add the EDNS0 Client Subnet
     option, overwriting any existing value.
@@ -389,37 +404,34 @@ class TestEdnsClientSubnetOverride(DNSDistTest):
         and that the response received from dnsdist does not
         have an EDNS pseudo-RR.
         """
-        name = 'withoutedns.overriden.ecs.tests.powerdns.com.'
+        name = 'withoutedns.overridden.ecs.tests.powerdns.com.'
         ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
         query = dns.message.make_query(name, 'A', 'IN')
         expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512)
         response = dns.message.make_response(expectedQuery)
-        expectedResponse = dns.message.make_response(query)
+        response.use_edns(edns=True, payload=4096, options=[ecso])
         rrset = dns.rrset.from_text(name,
                                     3600,
                                     dns.rdataclass.IN,
                                     dns.rdatatype.A,
                                     '127.0.0.1')
         response.answer.append(rrset)
+        expectedResponse = dns.message.make_response(query)
         expectedResponse.answer.append(rrset)
 
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-        self.assertEquals(receivedResponse.edns, -1)
-        self.assertEquals(len(receivedResponse.options), 0)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
         (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-        self.assertEquals(receivedResponse.edns, -1)
-        self.assertEquals(len(receivedResponse.options), 0)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
     def testWithEDNSNoECS(self):
         """
@@ -430,37 +442,34 @@ class TestEdnsClientSubnetOverride(DNSDistTest):
         has a valid ECS value and that the response
         received from dnsdist contains an EDNS pseudo-RR.
         """
-        name = 'withednsnoecs.overriden.ecs.tests.powerdns.com.'
+        name = 'withednsnoecs.overridden.ecs.tests.powerdns.com.'
         ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096)
         expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso])
         response = dns.message.make_response(expectedQuery)
-        expectedResponse = dns.message.make_response(query)
+        response.use_edns(edns=True, payload=4096, options=[ecso])
         rrset = dns.rrset.from_text(name,
                                     3600,
                                     dns.rdataclass.IN,
                                     dns.rdatatype.A,
                                     '127.0.0.1')
         response.answer.append(rrset)
+        expectedResponse = dns.message.make_response(query)
         expectedResponse.answer.append(rrset)
 
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-        self.assertEquals(receivedResponse.edns, 0)
-        self.assertEquals(len(receivedResponse.options), 0)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
 
         (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-        self.assertEquals(receivedResponse.edns, 0)
-        self.assertEquals(len(receivedResponse.options), 0)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
 
     def testWithEDNSShorterInitialECS(self):
         """
@@ -471,15 +480,16 @@ class TestEdnsClientSubnetOverride(DNSDistTest):
         has an overwritten ECS value (not the initial one)
         and that the response received from dnsdist contains
         an EDNS pseudo-RR.
-        The initial ECS value is shorter than the one it will
+        The initial ECS value is shorter than the one it will be
         replaced with.
         """
-        name = 'withednsecs.overriden.ecs.tests.powerdns.com.'
+        name = 'withednsecs.overridden.ecs.tests.powerdns.com.'
         ecso = clientsubnetoption.ClientSubnetOption('192.0.2.1', 8)
         rewrittenEcso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso])
         expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[rewrittenEcso])
         response = dns.message.make_response(query)
+        response.use_edns(edns=True, payload=4096, options=[rewrittenEcso])
         rrset = dns.rrset.from_text(name,
                                     3600,
                                     dns.rdataclass.IN,
@@ -491,19 +501,15 @@ class TestEdnsClientSubnetOverride(DNSDistTest):
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-        self.assertEquals(receivedResponse.edns, 0)
-        self.assertEquals(len(receivedResponse.options), 0)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithECS(response, receivedResponse)
 
         (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-        self.assertEquals(receivedResponse.edns, 0)
-        self.assertEquals(len(receivedResponse.options), 0)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithECS(response, receivedResponse)
 
     def testWithEDNSLongerInitialECS(self):
         """
@@ -517,12 +523,13 @@ class TestEdnsClientSubnetOverride(DNSDistTest):
         The initial ECS value is longer than the one it will
         replaced with.
         """
-        name = 'withednsecs.overriden.ecs.tests.powerdns.com.'
+        name = 'withednsecs.overridden.ecs.tests.powerdns.com.'
         ecso = clientsubnetoption.ClientSubnetOption('192.0.2.1', 32)
         rewrittenEcso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso])
         expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[rewrittenEcso])
         response = dns.message.make_response(query)
+        response.use_edns(edns=True, payload=4096, options=[rewrittenEcso])
         rrset = dns.rrset.from_text(name,
                                     3600,
                                     dns.rdataclass.IN,
@@ -534,19 +541,15 @@ class TestEdnsClientSubnetOverride(DNSDistTest):
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-        self.assertEquals(receivedResponse.edns, 0)
-        self.assertEquals(len(receivedResponse.options), 0)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithECS(response, receivedResponse)
 
         (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-        self.assertEquals(receivedResponse.edns, 0)
-        self.assertEquals(len(receivedResponse.options), 0)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithECS(response, receivedResponse)
 
     def testWithEDNSSameSizeInitialECS(self):
         """
@@ -560,12 +563,235 @@ class TestEdnsClientSubnetOverride(DNSDistTest):
         The initial ECS value is exactly the same size as
         the one it will replaced with.
         """
-        name = 'withednsecs.overriden.ecs.tests.powerdns.com.'
+        name = 'withednsecs.overridden.ecs.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('192.0.2.1', 24)
+        rewrittenEcso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso])
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[rewrittenEcso])
+        response = dns.message.make_response(query)
+        response.use_edns(edns=True, payload=4096, options=[rewrittenEcso])
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithECS(response, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithECS(response, receivedResponse)
+
+class TestECSDisabledByRuleOrLua(TestEdnsClientSubnet):
+    """
+    dnsdist is configured to add the EDNS0 Client Subnet
+    option, but we disable it via DisableECSAction()
+    or Lua.
+    """
+
+    _config_template = """
+    setECSOverride(false)
+    setECSSourcePrefixV4(16)
+    setECSSourcePrefixV6(16)
+    newServer{address="127.0.0.1:%s", useClientSubnet=true}
+    addAction(makeRule("disabled.ecsrules.tests.powerdns.com."), DisableECSAction())
+    function disableECSViaLua(dq)
+        dq.useECS = false
+        return DNSAction.None, ""
+    end
+    addLuaAction("disabledvialua.ecsrules.tests.powerdns.com.", disableECSViaLua)
+    """
+
+    def testWithECSNotDisabled(self):
+        """
+        ECS Disable: ECS enabled in the backend
+        """
+        name = 'notdisabled.ecsrules.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 16)
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512)
+        response = dns.message.make_response(expectedQuery)
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '::1')
+        response.answer.append(rrset)
+        expectedResponse.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+
+    def testWithECSDisabledViaRule(self):
+        """
+        ECS Disable: ECS enabled in the backend, but disabled by a rule
+        """
+        name = 'disabled.ecsrules.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.checkQueryNoEDNS(query, receivedQuery)
+        self.checkResponseNoEDNS(response, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.checkQueryNoEDNS(query, receivedQuery)
+        self.checkResponseNoEDNS(response, receivedResponse)
+
+    def testWithECSDisabledViaLua(self):
+        """
+        ECS Disable: ECS enabled in the backend, but disabled via Lua
+        """
+        name = 'disabledvialua.ecsrules.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.checkQueryNoEDNS(query, receivedQuery)
+        self.checkResponseNoEDNS(response, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.checkQueryNoEDNS(query, receivedQuery)
+        self.checkResponseNoEDNS(response, receivedResponse)
+
+class TestECSOverrideSetByRuleOrLua(TestEdnsClientSubnet):
+    """
+    dnsdist is configured to set the EDNS0 Client Subnet
+    option without overriding an existing one, but we
+    force the overriding via ECSOverrideAction() or Lua.
+    """
+
+    _config_template = """
+    setECSOverride(false)
+    setECSSourcePrefixV4(24)
+    setECSSourcePrefixV6(56)
+    newServer{address="127.0.0.1:%s", useClientSubnet=true}
+    addAction(makeRule("overridden.ecsrules.tests.powerdns.com."), ECSOverrideAction(true))
+    function overrideECSViaLua(dq)
+        dq.ecsOverride = true
+        return DNSAction.None, ""
+    end
+    addLuaAction("overriddenvialua.ecsrules.tests.powerdns.com.", overrideECSViaLua)
+    """
+
+    def testWithECSOverrideNotSet(self):
+        """
+        ECS Override: not set via Lua or a rule
+        """
+        name = 'notoverridden.ecsrules.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('192.0.2.1', 24)
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso])
+        response = dns.message.make_response(query)
+        response.use_edns(edns=True, payload=4096, options=[ecso])
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.checkQueryEDNSWithECS(query, receivedQuery)
+        self.checkResponseEDNSWithECS(response, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.checkQueryEDNSWithECS(query, receivedQuery)
+        self.checkResponseEDNSWithECS(response, receivedResponse)
+
+    def testWithECSOverrideSetViaRule(self):
+        """
+        ECS Override: set with a rule
+        """
+        name = 'overridden.ecsrules.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('192.0.2.1', 24)
+        rewrittenEcso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso])
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[rewrittenEcso])
+        response = dns.message.make_response(query)
+        response.use_edns(edns=True, payload=4096, options=[rewrittenEcso])
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithECS(response, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithECS(response, receivedResponse)
+
+    def testWithECSOverrideSetViaLua(self):
+        """
+        ECS Override: set via Lua
+        """
+        name = 'overriddenvialua.ecsrules.tests.powerdns.com.'
         ecso = clientsubnetoption.ClientSubnetOption('192.0.2.1', 24)
         rewrittenEcso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso])
         expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[rewrittenEcso])
         response = dns.message.make_response(query)
+        response.use_edns(edns=True, payload=4096, options=[rewrittenEcso])
         rrset = dns.rrset.from_text(name,
                                     3600,
                                     dns.rdataclass.IN,
@@ -577,16 +803,129 @@ class TestEdnsClientSubnetOverride(DNSDistTest):
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-        self.assertEquals(receivedResponse.edns, 0)
-        self.assertEquals(len(receivedResponse.options), 0)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithECS(response, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseEDNSWithECS(response, receivedResponse)
+
+class TestECSPrefixLengthSetByRuleOrLua(TestEdnsClientSubnet):
+    """
+    dnsdist is configured to set the EDNS0 Client Subnet
+    option with a prefix length of 24 for IPv4 and 56 for IPv6,
+    but we override that to 32 and 128 via ECSPrefixLengthAction() or Lua.
+    """
+
+    _config_template = """
+    setECSOverride(false)
+    setECSSourcePrefixV4(24)
+    setECSSourcePrefixV6(56)
+    newServer{address="127.0.0.1:%s", useClientSubnet=true}
+    addAction(makeRule("overriddenprefixlength.ecsrules.tests.powerdns.com."), ECSPrefixLengthAction(32, 128))
+    function overrideECSPrefixLengthViaLua(dq)
+        dq.ecsPrefixLength = 32
+        return DNSAction.None, ""
+    end
+    addLuaAction("overriddenprefixlengthvialua.ecsrules.tests.powerdns.com.", overrideECSPrefixLengthViaLua)
+    """
+
+    def testWithECSPrefixLengthNotOverridden(self):
+        """
+        ECS Prefix Length: not overridden via Lua or a rule
+        """
+        name = 'notoverriddenprefixlength.ecsrules.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512)
+        response = dns.message.make_response(query)
+        response.use_edns(edns=True, payload=4096, options=[ecso])
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+
+    def testWithECSPrefixLengthOverriddenViaRule(self):
+        """
+        ECS Prefix Length: overridden with a rule
+        """
+        name = 'overriddenprefixlength.ecsrules.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 32)
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512)
+        response = dns.message.make_response(expectedQuery)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+
+    def testWithECSPrefixLengthOverriddenViaLua(self):
+        """
+        ECS Prefix Length: overridden via Lua
+        """
+        name = 'overriddenprefixlengthvialua.ecsrules.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 32)
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512)
+        response = dns.message.make_response(expectedQuery)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = expectedQuery.id
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
         (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
         self.assertTrue(receivedQuery)
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-        self.assertEquals(receivedResponse.edns, 0)
-        self.assertEquals(len(receivedResponse.options), 0)
+        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+        self.checkResponseNoEDNS(expectedResponse, receivedResponse)