From: Remi Gacogne Date: Thu, 29 Sep 2016 08:48:04 +0000 (+0200) Subject: dnsdist: Allow altering the ECS behavior via rules and Lua X-Git-Tag: dnsdist-1.1.0-beta2~103^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=refs%2Fpull%2F4519%2Fhead;p=thirdparty%2Fpdns.git dnsdist: Allow altering the ECS behavior via rules and Lua --- diff --git a/pdns/README-dnsdist.md b/pdns/README-dnsdist.md index ef2246d72f..fb39794e69 100644 --- a/pdns/README-dnsdist.md +++ b/pdns/README-dnsdist.md @@ -167,6 +167,17 @@ be 192.0.2.0. This can be changed with: > setECSSourcePrefixV6(56) ``` +In addition to the global settings, rules and Lua bindings can alter this behavior per query: + +* calling `DisableECSAction()` or setting `dq.useECS` to false prevent the sending of the ECS option +* calling `ECSOverrideAction(bool)` or setting `dq.ecsOverride` will override the global `setECSOverride()` value +* calling `ECSPrefixLengthAction(v4, v6)` or setting `dq.ecsPrefixLength` will override the global +`setECSSourcePrefixV4()` and `setECSSourcePrefixV6()` values + +In effect this means that for the EDNS Client Subnet option to be added to the request, `useClientSubnet` +should be set to true for the backend used (default to false) and ECS should not have been disabled by calling +`DisableECSAction()` or setting `dq.useECS` to false (default to true). + TCP timeouts ------------ @@ -362,6 +373,7 @@ Current actions are: * Add the source MAC address to the query (MacAddrAction) * Skip the cache, if any * Log query content to a remote server (RemoteLogAction) + * Alter the EDNS Client Subnet parameters (DisableECSAction, ECSOverrideAction, ECSPrefixLengthAction) Current response actions are: @@ -415,7 +427,10 @@ A DNS rule can be: Some specific actions do not stop the processing when they match, contrary to all other actions: * Delay + * DisableECS * Disable Validation + * ECSOverride + * ECSPrefixLength * Log * MacAddr * No Recurse @@ -1312,9 +1327,12 @@ instantiate a server with additional parameters * `AllowResponseAction()`: let these packets go through * `DelayAction(milliseconds)`: delay the response by the specified amount of milliseconds (UDP-only) * `DelayResponseAction(milliseconds)`: delay the response by the specified amount of milliseconds (UDP-only) + * `DisableECSAction()`: disable the sending of ECS to the backend * `DisableValidationAction()`: set the CD bit in the question, let it go through * `DropAction()`: drop these packets * `DropResponseAction()`: drop these packets + * `ECSOverrideAction(bool)`: whether an existing ECS value should be overriden (true) or not (false) + * `ECSPrefixLengthAction(v4, v6)`: set the ECS prefix length * `LogAction([filename], [binary], [append], [buffered])`: Log a line for each query, to the specified file if any, to the console (require verbose) otherwise. When logging to a file, the `binary` optional parameter specifies whether we log in binary form (default) or in textual form, the `append` optional parameter specifies whether we open the file for appending or truncate each time (default), and the `buffered` optional parameter specifies whether writes to the file are buffered (default) or not. * `NoRecurseAction()`: strip RD bit from the question, let it go through * `PoolAction(poolname)`: set the packet into the specified pool @@ -1420,6 +1438,8 @@ instantiate a server with additional parameters * member `wirelength()`: return the length on the wire * DNSQuestion related: * member `dh`: DNSHeader + * member `ecsOverride`: whether an existing ECS value should be overriden (settable) + * member `ecsPrefixLength`: the ECS prefix length to use (settable) * member `len`: the question length * member `localaddr`: ComboAddress of the local bind this question was received on * member `opcode`: the question opcode @@ -1431,6 +1451,7 @@ instantiate a server with additional parameters * member `size`: the total size of the buffer starting at `dh` * member `skipCache`: whether to skip cache lookup / storing the answer for this question (settable) * member `tcp`: whether this question was received over a TCP socket + * member `useECS`: whether to send ECS to the backend (settable) * DNSHeader related * member `getRD()`: get recursion desired flag * member `setRD(bool)`: set recursion desired flag diff --git a/pdns/dnsdist-ecs.cc b/pdns/dnsdist-ecs.cc index d78821ed45..bab9e3fea4 100644 --- a/pdns/dnsdist-ecs.cc +++ b/pdns/dnsdist-ecs.cc @@ -29,7 +29,7 @@ /* when we add EDNS to a query, we don't want to advertise a large buffer size */ -size_t q_EdnsUDPPayloadSize = 512; +size_t g_EdnsUDPPayloadSize = 512; /* draft-ietf-dnsop-edns-client-subnet-04 "11.1. Privacy" */ uint16_t g_ECSSourcePrefixV4 = 24; uint16_t g_ECSSourcePrefixV6 = 56; @@ -231,9 +231,9 @@ static int getEDNSOptionsStart(char* packet, const size_t offset, const size_t l return 0; } -static void generateECSOption(const ComboAddress& source, string& res) +static void generateECSOption(const ComboAddress& source, string& res, uint16_t ECSPrefixLength) { - Netmask sourceNetmask(source, source.sin4.sin_family == AF_INET ? g_ECSSourcePrefixV4 : g_ECSSourcePrefixV6); + Netmask sourceNetmask(source, ECSPrefixLength); EDNSSubnetOpts ecsOpts; ecsOpts.source = sourceNetmask; string payload = makeEDNSSubnetOptsString(ecsOpts); @@ -250,7 +250,7 @@ void generateOptRR(const std::string& optRData, string& res) edns0.Z = 0; dh.d_type = htons(QType::OPT); - dh.d_class = htons(q_EdnsUDPPayloadSize); + dh.d_class = htons(g_EdnsUDPPayloadSize); static_assert(sizeof(EDNS0Record) == sizeof(dh.d_ttl), "sizeof(EDNS0Record) must match sizeof(dnsrecordheader.d_ttl)"); memcpy(&dh.d_ttl, &edns0, sizeof edns0); dh.d_clen = htons((uint16_t) optRData.length()); @@ -259,14 +259,14 @@ void generateOptRR(const std::string& optRData, string& res) res.append(optRData.c_str(), optRData.length()); } -static void replaceEDNSClientSubnetOption(char * const packet, const size_t packetSize, uint16_t * const len, string& largerPacket, const ComboAddress& remote, char * const oldEcsOptionStart, size_t const oldEcsOptionSize, unsigned char * const optRDLen) +static void replaceEDNSClientSubnetOption(char * const packet, const size_t packetSize, uint16_t * const len, string& largerPacket, const ComboAddress& remote, char * const oldEcsOptionStart, size_t const oldEcsOptionSize, unsigned char * const optRDLen, uint16_t ECSPrefixLength) { assert(packet != NULL); assert(len != NULL); assert(oldEcsOptionStart != NULL); assert(optRDLen != NULL); string ECSOption; - generateECSOption(remote, ECSOption); + generateECSOption(remote, ECSOption, ECSPrefixLength); if (ECSOption.size() == oldEcsOptionSize) { /* same size as the existing option */ @@ -309,7 +309,7 @@ static void replaceEDNSClientSubnetOption(char * const packet, const size_t pack } } -void handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, string& largerPacket, bool* const ednsAdded, bool* const ecsAdded, const ComboAddress& remote) +void handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, string& largerPacket, bool* const ednsAdded, bool* const ecsAdded, const ComboAddress& remote, bool overrideExisting, uint16_t ecsPrefixLength) { assert(packet != NULL); assert(len != NULL); @@ -318,7 +318,7 @@ void handleEDNSClientSubnet(char* const packet, const size_t packetSize, const u assert(ecsAdded != NULL); unsigned char * optRDLen = NULL; size_t remaining = 0; - + int res = getEDNSOptionsStart(packet, consumed, *len, (char**) &optRDLen, &remaining); if (res == 0) { @@ -329,15 +329,15 @@ void handleEDNSClientSubnet(char* const packet, const size_t packetSize, const u if (res == 0) { /* there is already an ECS value */ - if (g_ECSOverride) { - replaceEDNSClientSubnetOption(packet, packetSize, len, largerPacket, remote, ecsOptionStart, ecsOptionSize, optRDLen); + if (overrideExisting) { + replaceEDNSClientSubnetOption(packet, packetSize, len, largerPacket, remote, ecsOptionStart, ecsOptionSize, optRDLen, ecsPrefixLength); } } else { /* we need to add one EDNS0 ECS option, fixing the size of EDNS0 RDLENGTH */ /* getEDNSOptionsStart has already checked that there is exactly one AR, no NS and no AN */ string ECSOption; - generateECSOption(remote, ECSOption); + generateECSOption(remote, ECSOption, ecsPrefixLength); const size_t ECSOptionSize = ECSOption.size(); uint16_t newRDLen = (optRDLen[0] * 256) + optRDLen[1]; @@ -366,7 +366,7 @@ void handleEDNSClientSubnet(char* const packet, const size_t packetSize, const u string EDNSRR; struct dnsheader* dh = (struct dnsheader*) packet; string optRData; - generateECSOption(remote, optRData); + generateECSOption(remote, optRData, ecsPrefixLength); generateOptRR(optRData, EDNSRR); uint16_t arcount = ntohs(dh->arcount); arcount++; diff --git a/pdns/dnsdist-ecs.hh b/pdns/dnsdist-ecs.hh index 5f197c000d..84013eb05b 100644 --- a/pdns/dnsdist-ecs.hh +++ b/pdns/dnsdist-ecs.hh @@ -23,7 +23,7 @@ int rewriteResponseWithoutEDNS(const char * packet, size_t len, vector& newContent); int locateEDNSOptRR(char * packet, size_t len, char ** optStart, size_t * optLen, bool * last); -void handleEDNSClientSubnet(char * packet, size_t packetSize, unsigned int consumed, uint16_t * len, string& largerPacket, bool* ednsAdded, bool* ecsAdded, const ComboAddress& remote); +void handleEDNSClientSubnet(char * packet, size_t packetSize, unsigned int consumed, uint16_t * len, string& largerPacket, bool* ednsAdded, bool* ecsAdded, const ComboAddress& remote, bool overrideExisting, uint16_t ecsPrefixLength); void generateOptRR(const std::string& optRData, string& res); int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optionCodeToRemove); int rewriteResponseWithoutEDNSOption(const char * packet, const size_t len, const uint16_t optionCodeToSkip, vector& newContent); diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index fcb8a590ad..c5e9059948 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -1481,6 +1481,9 @@ vector> setupLua(bool client, const std::string& confi g_lua.registerMember("size", [](const DNSQuestion& dq) -> size_t { return dq.size; }, [](DNSQuestion& dq, size_t newSize) { (void) newSize; }); g_lua.registerMember("tcp", [](const DNSQuestion& dq) -> bool { return dq.tcp; }, [](DNSQuestion& dq, bool newTcp) { (void) newTcp; }); g_lua.registerMember("skipCache", [](const DNSQuestion& dq) -> bool { return dq.skipCache; }, [](DNSQuestion& dq, bool newSkipCache) { dq.skipCache = newSkipCache; }); + g_lua.registerMember("useECS", [](const DNSQuestion& dq) -> bool { return dq.useECS; }, [](DNSQuestion& dq, bool useECS) { dq.useECS = useECS; }); + g_lua.registerMember("ecsOverride", [](const DNSQuestion& dq) -> bool { return dq.ecsOverride; }, [](DNSQuestion& dq, bool ecsOverride) { dq.ecsOverride = ecsOverride; }); + g_lua.registerMember("ecsPrefixLength", [](const DNSQuestion& dq) -> uint16_t { return dq.ecsPrefixLength; }, [](DNSQuestion& dq, uint16_t newPrefixLength) { dq.ecsPrefixLength = newPrefixLength; }); g_lua.writeFunction("setMaxTCPClientThreads", [](uint64_t max) { if (!g_configurationDone) { diff --git a/pdns/dnsdist-lua2.cc b/pdns/dnsdist-lua2.cc index 7bd9f8826c..7afe7c0f87 100644 --- a/pdns/dnsdist-lua2.cc +++ b/pdns/dnsdist-lua2.cc @@ -743,6 +743,18 @@ void moreLua(bool client) return std::shared_ptr(new TeeAction(ComboAddress(remote, 53), addECS ? *addECS : false)); }); + g_lua.writeFunction("ECSPrefixLengthAction", [](uint16_t v4PrefixLength, uint16_t v6PrefixLength) { + return std::shared_ptr(new ECSPrefixLengthAction(v4PrefixLength, v6PrefixLength)); + }); + + g_lua.writeFunction("ECSOverrideAction", [](bool ecsOverride) { + return std::shared_ptr(new ECSOverrideAction(ecsOverride)); + }); + + g_lua.writeFunction("DisableECSAction", []() { + return std::shared_ptr(new DisableECSAction()); + }); + g_lua.registerFunction("printStats", [](const DNSAction& ta) { setLuaNoSideEffect(); auto stats = ta.getStats(); diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index e57f8064f2..21f43570d5 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -292,9 +292,9 @@ void* tcpClientThread(int pipefd) packetCache = serverPool->packetCache; } - if (ds && ds->useECS) { + if (dq.useECS && ds && ds->useECS) { uint16_t newLen = dq.len; - handleEDNSClientSubnet(queryBuffer, dq.size, consumed, &newLen, largerQuery, &ednsAdded, &ecsAdded, ci.remote); + handleEDNSClientSubnet(queryBuffer, dq.size, consumed, &newLen, largerQuery, &ednsAdded, &ecsAdded, ci.remote, dq.ecsOverride, dq.ecsPrefixLength); if (largerQuery.empty() == false) { query = largerQuery.c_str(); dq.len = (uint16_t) largerQuery.size(); diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index f4fcd72f33..847cdcacf3 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -1040,8 +1040,8 @@ try bool ednsAdded = false; bool ecsAdded = false; - if (ss && ss->useECS) { - handleEDNSClientSubnet(query, dq.size, consumed, &dq.len, largerQuery, &(ednsAdded), &(ecsAdded), remote); + if (dq.useECS && ss && ss->useECS) { + handleEDNSClientSubnet(query, dq.size, consumed, &dq.len, largerQuery, &(ednsAdded), &(ecsAdded), remote, dq.ecsOverride, dq.ecsPrefixLength); } uint32_t cacheKey = 0; diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index b494e37a02..e6132997e9 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -444,9 +444,13 @@ struct DownstreamState }; using servers_t =vector>; +extern uint16_t g_ECSSourcePrefixV4; +extern uint16_t g_ECSSourcePrefixV6; +extern bool g_ECSOverride; + struct DNSQuestion { - DNSQuestion(const DNSName* name, uint16_t type, uint16_t class_, const ComboAddress* lc, const ComboAddress* rem, struct dnsheader* header, size_t bufferSize, uint16_t queryLen, bool isTcp): qname(name), qtype(type), qclass(class_), local(lc), remote(rem), dh(header), size(bufferSize), len(queryLen), tcp(isTcp) { } + DNSQuestion(const DNSName* name, uint16_t type, uint16_t class_, const ComboAddress* lc, const ComboAddress* rem, struct dnsheader* header, size_t bufferSize, uint16_t queryLen, bool isTcp): qname(name), qtype(type), qclass(class_), local(lc), remote(rem), dh(header), size(bufferSize), len(queryLen), ecsPrefixLength(rem->sin4.sin_family == AF_INET ? g_ECSSourcePrefixV4 : g_ECSSourcePrefixV6), tcp(isTcp), ecsOverride(g_ECSOverride) { } #ifdef HAVE_PROTOBUF boost::uuids::uuid uniqueId; @@ -459,8 +463,11 @@ struct DNSQuestion struct dnsheader* dh; size_t size; uint16_t len; + uint16_t ecsPrefixLength; const bool tcp; bool skipCache{false}; + bool ecsOverride; + bool useECS{true}; }; struct DNSResponse : DNSQuestion @@ -665,9 +672,6 @@ extern std::atomic g_configurationDone; extern uint64_t g_maxTCPClientThreads; extern uint64_t g_maxTCPQueuedConnections; extern std::atomic g_cacheCleaningDelay; -extern uint16_t g_ECSSourcePrefixV4; -extern uint16_t g_ECSSourcePrefixV6; -extern bool g_ECSOverride; extern bool g_verboseHealthChecks; extern uint32_t g_staleCacheEntriesTTL; diff --git a/pdns/dnsdistdist/dnsrulactions.cc b/pdns/dnsdistdist/dnsrulactions.cc index 8518a7197d..49292763e2 100644 --- a/pdns/dnsdistdist/dnsrulactions.cc +++ b/pdns/dnsdistdist/dnsrulactions.cc @@ -56,7 +56,7 @@ DNSAction::Action TeeAction::operator()(DNSQuestion* dq, string* ruleresult) con query.reserve(dq->size); query.assign((char*) dq->dh, len); - handleEDNSClientSubnet((char*) query.c_str(), query.size(), dq->qname->wirelength(), &len, larger, &ednsAdded, &ecsAdded, *dq->remote); + handleEDNSClientSubnet((char*) query.c_str(), query.size(), dq->qname->wirelength(), &len, larger, &ednsAdded, &ecsAdded, *dq->remote, dq->ecsOverride, dq->ecsPrefixLength); if (larger.empty()) { res = send(d_fd, query.c_str(), len, 0); diff --git a/pdns/dnsrulactions.hh b/pdns/dnsrulactions.hh index 69a7564c48..0244cd6ea9 100644 --- a/pdns/dnsrulactions.hh +++ b/pdns/dnsrulactions.hh @@ -1001,6 +1001,60 @@ public: } }; +class ECSPrefixLengthAction : public DNSAction +{ +public: + ECSPrefixLengthAction(uint16_t v4Length, uint16_t v6Length) : d_v4PrefixLength(v4Length), d_v6PrefixLength(v6Length) + { + } + DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override + { + dq->ecsPrefixLength = dq->remote->sin4.sin_family == AF_INET ? d_v4PrefixLength : d_v6PrefixLength; + return Action::None; + } + string toString() const override + { + return "set ECS prefix length to " + std::to_string(d_v4PrefixLength) + "/" + std::to_string(d_v6PrefixLength); + } +private: + uint16_t d_v4PrefixLength; + uint16_t d_v6PrefixLength; +}; + +class ECSOverrideAction : public DNSAction +{ +public: + ECSOverrideAction(bool ecsOverride) : d_ecsOverride(ecsOverride) + { + } + DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override + { + dq->ecsOverride = d_ecsOverride; + return Action::None; + } + string toString() const override + { + return "set ECS override to " + std::to_string(d_ecsOverride); + } +private: + bool d_ecsOverride; +}; + + +class DisableECSAction : public DNSAction +{ +public: + DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override + { + dq->useECS = false; + return Action::None; + } + string toString() const override + { + return "disable ECS"; + } +}; + class RemoteLogAction : public DNSAction, public boost::noncopyable { public: diff --git a/pdns/test-dnsdist_cc.cc b/pdns/test-dnsdist_cc.cc index fe8a8ff360..7b17c4408d 100644 --- a/pdns/test-dnsdist_cc.cc +++ b/pdns/test-dnsdist_cc.cc @@ -43,6 +43,9 @@ bool g_console{true}; bool g_syslog{true}; bool g_verbose{true}; +static const uint16_t ECSSourcePrefixV4 = 24; +static const uint16_t ECSSourcePrefixV6 = 56; + static void validateQuery(const char * packet, size_t packetSize) { MOADNSParser mdp(packet, packetSize); @@ -91,7 +94,7 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNS) BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote); + handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); BOOST_CHECK((size_t) len > query.size()); BOOST_CHECK_EQUAL(largerPacket.size(), 0); BOOST_CHECK_EQUAL(ednsAdded, true); @@ -105,7 +108,7 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNS) BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote); + handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); BOOST_CHECK_EQUAL((size_t) len, query.size()); BOOST_CHECK(largerPacket.size() > query.size()); BOOST_CHECK_EQUAL(ednsAdded, true); @@ -137,7 +140,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) { BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote); + handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); BOOST_CHECK((size_t) len > query.size()); BOOST_CHECK_EQUAL(largerPacket.size(), 0); BOOST_CHECK_EQUAL(ednsAdded, false); @@ -151,7 +154,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) { BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote); + handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, false, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); BOOST_CHECK_EQUAL((size_t) len, query.size()); BOOST_CHECK(largerPacket.size() > query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); @@ -171,7 +174,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSameSize) { DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0); pw.getHeader()->rd = 1; EDNSSubnetOpts ecsOpts; - ecsOpts.source = Netmask(origRemote, g_ECSSourcePrefixV4); + ecsOpts.source = Netmask(origRemote, ECSSourcePrefixV4); string origECSOption = makeEDNSSubnetOptsString(ecsOpts); DNSPacketWriter::optvect_t opts; opts.push_back(make_pair(EDNSOptionCode::ECS, origECSOption)); @@ -189,8 +192,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSameSize) { BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - g_ECSOverride = true; - handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote); + handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); BOOST_CHECK_EQUAL((size_t) len, query.size()); BOOST_CHECK_EQUAL(largerPacket.size(), 0); BOOST_CHECK_EQUAL(ednsAdded, false); @@ -228,8 +230,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSmaller) { BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - g_ECSOverride = true; - handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote); + handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); BOOST_CHECK((size_t) len < query.size()); BOOST_CHECK_EQUAL(largerPacket.size(), 0); BOOST_CHECK_EQUAL(ednsAdded, false); @@ -267,8 +268,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithLarger) { BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - g_ECSOverride = true; - handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote); + handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); BOOST_CHECK((size_t) len > query.size()); BOOST_CHECK_EQUAL(largerPacket.size(), 0); BOOST_CHECK_EQUAL(ednsAdded, false); @@ -282,8 +282,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithLarger) { BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - g_ECSOverride = true; - handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote); + handleEDNSClientSubnet((char*) query.data(), query.size(), consumed, &len, largerPacket, &ednsAdded, &ecsAdded, remote, true, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); BOOST_CHECK_EQUAL((size_t) len, query.size()); BOOST_CHECK(largerPacket.size() > query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); @@ -398,7 +397,7 @@ BOOST_AUTO_TEST_CASE(removeECSWhenOnlyOption) { pw.commit(); EDNSSubnetOpts ecsOpts; - ecsOpts.source = Netmask(origRemote, g_ECSSourcePrefixV4); + ecsOpts.source = Netmask(origRemote, ECSSourcePrefixV4); string origECSOptionStr = makeEDNSSubnetOptsString(ecsOpts); DNSPacketWriter::optvect_t opts; opts.push_back(make_pair(EDNSOptionCode::ECS, origECSOptionStr)); @@ -445,7 +444,7 @@ BOOST_AUTO_TEST_CASE(removeECSWhenFirstOption) { pw.commit(); EDNSSubnetOpts ecsOpts; - ecsOpts.source = Netmask(origRemote, g_ECSSourcePrefixV4); + ecsOpts.source = Netmask(origRemote, ECSSourcePrefixV6); string origECSOptionStr = makeEDNSSubnetOptsString(ecsOpts); EDNSCookiesOpt cookiesOpt; cookiesOpt.client = string("deadbeef"); @@ -497,7 +496,7 @@ BOOST_AUTO_TEST_CASE(removeECSWhenIntermediaryOption) { pw.commit(); EDNSSubnetOpts ecsOpts; - ecsOpts.source = Netmask(origRemote, g_ECSSourcePrefixV4); + ecsOpts.source = Netmask(origRemote, ECSSourcePrefixV4); string origECSOptionStr = makeEDNSSubnetOptsString(ecsOpts); EDNSCookiesOpt cookiesOpt; @@ -557,7 +556,7 @@ BOOST_AUTO_TEST_CASE(removeECSWhenLastOption) { cookiesOpt.server = string("deadbeef"); string cookiesOptionStr = makeEDNSCookiesOptString(cookiesOpt); EDNSSubnetOpts ecsOpts; - ecsOpts.source = Netmask(origRemote, g_ECSSourcePrefixV4); + ecsOpts.source = Netmask(origRemote, ECSSourcePrefixV4); string origECSOptionStr = makeEDNSSubnetOptsString(ecsOpts); DNSPacketWriter::optvect_t opts; opts.push_back(make_pair(EDNSOptionCode::COOKIE, cookiesOptionStr)); @@ -601,7 +600,7 @@ BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenOnlyOption) { pw.xfr32BitInt(0x01020304); EDNSSubnetOpts ecsOpts; - ecsOpts.source = Netmask(origRemote, g_ECSSourcePrefixV4); + ecsOpts.source = Netmask(origRemote, ECSSourcePrefixV4); string origECSOptionStr = makeEDNSSubnetOptsString(ecsOpts); DNSPacketWriter::optvect_t opts; opts.push_back(make_pair(EDNSOptionCode::ECS, origECSOptionStr)); @@ -638,7 +637,7 @@ BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenFirstOption) { pw.xfr32BitInt(0x01020304); EDNSSubnetOpts ecsOpts; - ecsOpts.source = Netmask(origRemote, g_ECSSourcePrefixV4); + ecsOpts.source = Netmask(origRemote, ECSSourcePrefixV4); string origECSOptionStr = makeEDNSSubnetOptsString(ecsOpts); EDNSCookiesOpt cookiesOpt; cookiesOpt.client = string("deadbeef"); @@ -680,7 +679,7 @@ BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenIntermediaryOption) { pw.xfr32BitInt(0x01020304); EDNSSubnetOpts ecsOpts; - ecsOpts.source = Netmask(origRemote, g_ECSSourcePrefixV4); + ecsOpts.source = Netmask(origRemote, ECSSourcePrefixV4); string origECSOptionStr = makeEDNSSubnetOptsString(ecsOpts); EDNSCookiesOpt cookiesOpt; cookiesOpt.client = string("deadbeef"); @@ -724,7 +723,7 @@ BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenLastOption) { pw.xfr32BitInt(0x01020304); EDNSSubnetOpts ecsOpts; - ecsOpts.source = Netmask(origRemote, g_ECSSourcePrefixV4); + ecsOpts.source = Netmask(origRemote, ECSSourcePrefixV4); string origECSOptionStr = makeEDNSSubnetOptsString(ecsOpts); EDNSCookiesOpt cookiesOpt; cookiesOpt.client = string("deadbeef"); diff --git a/regression-tests.dnsdist/test_EdnsClientSubnet.py b/regression-tests.dnsdist/test_EdnsClientSubnet.py index 273a7dee56..701fe01929 100644 --- a/regression-tests.dnsdist/test_EdnsClientSubnet.py +++ b/regression-tests.dnsdist/test_EdnsClientSubnet.py @@ -4,7 +4,51 @@ import clientsubnetoption import cookiesoption from dnsdisttests import DNSDistTest -class TestEdnsClientSubnetNoOverride(DNSDistTest): +class TestEdnsClientSubnet(DNSDistTest): + def compareOptions(self, a, b): + self.assertEquals(len(a), len(b)) + for idx in xrange(len(a)): + self.assertEquals(a[idx], b[idx]) + + def checkMessageNoEDNS(self, expected, received): + self.assertEquals(expected, received) + self.assertEquals(received.edns, -1) + self.assertEquals(len(received.options), 0) + + def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0): + self.assertEquals(expected, received) + self.assertEquals(received.edns, 0) + self.assertEquals(len(received.options), withCookies) + if withCookies: + for option in received.options: + self.assertEquals(option.otype, 10) + + def checkMessageEDNSWithECS(self, expected, received): + self.assertEquals(expected, received) + self.assertEquals(received.edns, 0) + self.assertEquals(len(received.options), 1) + self.assertEquals(received.options[0].otype, clientsubnetoption.ASSIGNED_OPTION_CODE) + self.compareOptions(expected.options, received.options) + + def checkQueryEDNSWithECS(self, expected, received): + self.checkMessageEDNSWithECS(expected, received) + + def checkResponseEDNSWithECS(self, expected, received): + self.checkMessageEDNSWithECS(expected, received) + + def checkQueryEDNSWithoutECS(self, expected, received): + self.checkMessageEDNSWithoutECS(expected, received) + + def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0): + self.checkMessageEDNSWithoutECS(expected, received, withCookies) + + def checkQueryNoEDNS(self, expected, received): + self.checkMessageNoEDNS(expected, received) + + def checkResponseNoEDNS(self, expected, received): + self.checkMessageNoEDNS(expected, received) + +class TestEdnsClientSubnetNoOverride(TestEdnsClientSubnet): """ dnsdist is configured to add the EDNS0 Client Subnet option, but only if it's not already present in the @@ -43,19 +87,15 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(expectedResponse, receivedResponse) - self.assertEquals(receivedResponse.edns, -1) - self.assertEquals(len(receivedResponse.options), 0) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseNoEDNS(expectedResponse, receivedResponse) (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(expectedResponse, receivedResponse) - self.assertEquals(receivedResponse.edns, -1) - self.assertEquals(len(receivedResponse.options), 0) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseNoEDNS(expectedResponse, receivedResponse) def testWithEDNSNoECS(self): """ @@ -84,19 +124,15 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(expectedResponse, receivedResponse) - self.assertEquals(receivedResponse.edns, 0) - self.assertEquals(len(receivedResponse.options), 0) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse) (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(expectedResponse, receivedResponse) - self.assertEquals(receivedResponse.edns, 0) - self.assertEquals(len(receivedResponse.options), 0) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse) def testWithEDNSECS(self): """ @@ -119,23 +155,20 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest): '127.0.0.1') response.answer.append(rrset) + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = query.id - self.assertEquals(query, receivedQuery) - self.assertEquals(response, receivedResponse) - self.assertEquals(receivedResponse.edns, 0) - self.assertEquals(len(receivedResponse.options), 0) + self.checkQueryEDNSWithECS(query, receivedQuery) + self.checkResponseEDNSWithoutECS(response, receivedResponse) (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = query.id - self.assertEquals(query, receivedQuery) - self.assertEquals(response, receivedResponse) - self.assertEquals(receivedResponse.edns, 0) - self.assertEquals(len(receivedResponse.options), 0) + self.checkQueryEDNSWithECS(query, receivedQuery) + self.checkResponseEDNSWithoutECS(response, receivedResponse) def testWithoutEDNSResponseWithECS(self): """ @@ -168,19 +201,15 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(expectedResponse, receivedResponse) - self.assertEquals(receivedResponse.edns, -1) - self.assertEquals(len(receivedResponse.options), 0) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseNoEDNS(expectedResponse, receivedResponse) (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(expectedResponse, receivedResponse) - self.assertEquals(receivedResponse.edns, -1) - self.assertEquals(len(receivedResponse.options), 0) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseNoEDNS(expectedResponse, receivedResponse) def testWithEDNSNoECSResponseWithECS(self): """ @@ -213,19 +242,15 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(expectedResponse, receivedResponse) - self.assertEquals(receivedResponse.edns, 0) - self.assertEquals(len(receivedResponse.options), 0) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse) (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(expectedResponse, receivedResponse) - self.assertEquals(receivedResponse.edns, 0) - self.assertEquals(len(receivedResponse.options), 0) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse) def testWithEDNSNoECSResponseWithCookiesThenECS(self): """ @@ -254,24 +279,21 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest): '127.0.0.1') response.answer.append(rrset) expectedResponse.answer.append(rrset) + expectedResponse.use_edns(edns=True, payload=4096, options=[ecoResponse]) (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(expectedResponse, receivedResponse) - self.assertEquals(receivedResponse.edns, 0) - self.assertEquals(len(receivedResponse.options), 1) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=1) (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(expectedResponse, receivedResponse) - self.assertEquals(receivedResponse.edns, 0) - self.assertEquals(len(receivedResponse.options), 1) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=1) def testWithEDNSNoECSResponseWithECSThenCookies(self): """ @@ -300,24 +322,21 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest): '127.0.0.1') response.answer.append(rrset) expectedResponse.answer.append(rrset) + response.use_edns(edns=True, payload=4096, options=[ecoResponse]) (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(expectedResponse, receivedResponse) - self.assertEquals(receivedResponse.edns, 0) - self.assertEquals(len(receivedResponse.options), 1) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=1) (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(expectedResponse, receivedResponse) - self.assertEquals(receivedResponse.edns, 0) - self.assertEquals(len(receivedResponse.options), 1) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=1) def testWithEDNSNoECSResponseWithCookiesThenECSThenCookies(self): """ @@ -351,22 +370,18 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(expectedResponse, receivedResponse) - self.assertEquals(receivedResponse.edns, 0) - self.assertEquals(len(receivedResponse.options), 2) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=2) (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(expectedResponse, receivedResponse) - self.assertEquals(receivedResponse.edns, 0) - self.assertEquals(len(receivedResponse.options), 2) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=2) -class TestEdnsClientSubnetOverride(DNSDistTest): +class TestEdnsClientSubnetOverride(TestEdnsClientSubnet): """ dnsdist is configured to add the EDNS0 Client Subnet option, overwriting any existing value. @@ -389,37 +404,34 @@ class TestEdnsClientSubnetOverride(DNSDistTest): and that the response received from dnsdist does not have an EDNS pseudo-RR. """ - name = 'withoutedns.overriden.ecs.tests.powerdns.com.' + name = 'withoutedns.overridden.ecs.tests.powerdns.com.' ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24) query = dns.message.make_query(name, 'A', 'IN') expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512) response = dns.message.make_response(expectedQuery) - expectedResponse = dns.message.make_response(query) + response.use_edns(edns=True, payload=4096, options=[ecso]) rrset = dns.rrset.from_text(name, 3600, dns.rdataclass.IN, dns.rdatatype.A, '127.0.0.1') response.answer.append(rrset) + expectedResponse = dns.message.make_response(query) expectedResponse.answer.append(rrset) (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(expectedResponse, receivedResponse) - self.assertEquals(receivedResponse.edns, -1) - self.assertEquals(len(receivedResponse.options), 0) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseNoEDNS(expectedResponse, receivedResponse) (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(expectedResponse, receivedResponse) - self.assertEquals(receivedResponse.edns, -1) - self.assertEquals(len(receivedResponse.options), 0) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseNoEDNS(expectedResponse, receivedResponse) def testWithEDNSNoECS(self): """ @@ -430,37 +442,34 @@ class TestEdnsClientSubnetOverride(DNSDistTest): has a valid ECS value and that the response received from dnsdist contains an EDNS pseudo-RR. """ - name = 'withednsnoecs.overriden.ecs.tests.powerdns.com.' + name = 'withednsnoecs.overridden.ecs.tests.powerdns.com.' ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24) query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096) expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso]) response = dns.message.make_response(expectedQuery) - expectedResponse = dns.message.make_response(query) + response.use_edns(edns=True, payload=4096, options=[ecso]) rrset = dns.rrset.from_text(name, 3600, dns.rdataclass.IN, dns.rdatatype.A, '127.0.0.1') response.answer.append(rrset) + expectedResponse = dns.message.make_response(query) expectedResponse.answer.append(rrset) (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(expectedResponse, receivedResponse) - self.assertEquals(receivedResponse.edns, 0) - self.assertEquals(len(receivedResponse.options), 0) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse) (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(expectedResponse, receivedResponse) - self.assertEquals(receivedResponse.edns, 0) - self.assertEquals(len(receivedResponse.options), 0) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse) def testWithEDNSShorterInitialECS(self): """ @@ -471,15 +480,16 @@ class TestEdnsClientSubnetOverride(DNSDistTest): has an overwritten ECS value (not the initial one) and that the response received from dnsdist contains an EDNS pseudo-RR. - The initial ECS value is shorter than the one it will + The initial ECS value is shorter than the one it will be replaced with. """ - name = 'withednsecs.overriden.ecs.tests.powerdns.com.' + name = 'withednsecs.overridden.ecs.tests.powerdns.com.' ecso = clientsubnetoption.ClientSubnetOption('192.0.2.1', 8) rewrittenEcso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24) query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso]) expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[rewrittenEcso]) response = dns.message.make_response(query) + response.use_edns(edns=True, payload=4096, options=[rewrittenEcso]) rrset = dns.rrset.from_text(name, 3600, dns.rdataclass.IN, @@ -491,19 +501,15 @@ class TestEdnsClientSubnetOverride(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(response, receivedResponse) - self.assertEquals(receivedResponse.edns, 0) - self.assertEquals(len(receivedResponse.options), 0) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithECS(response, receivedResponse) (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(response, receivedResponse) - self.assertEquals(receivedResponse.edns, 0) - self.assertEquals(len(receivedResponse.options), 0) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithECS(response, receivedResponse) def testWithEDNSLongerInitialECS(self): """ @@ -517,12 +523,13 @@ class TestEdnsClientSubnetOverride(DNSDistTest): The initial ECS value is longer than the one it will replaced with. """ - name = 'withednsecs.overriden.ecs.tests.powerdns.com.' + name = 'withednsecs.overridden.ecs.tests.powerdns.com.' ecso = clientsubnetoption.ClientSubnetOption('192.0.2.1', 32) rewrittenEcso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24) query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso]) expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[rewrittenEcso]) response = dns.message.make_response(query) + response.use_edns(edns=True, payload=4096, options=[rewrittenEcso]) rrset = dns.rrset.from_text(name, 3600, dns.rdataclass.IN, @@ -534,19 +541,15 @@ class TestEdnsClientSubnetOverride(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(response, receivedResponse) - self.assertEquals(receivedResponse.edns, 0) - self.assertEquals(len(receivedResponse.options), 0) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithECS(response, receivedResponse) (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(response, receivedResponse) - self.assertEquals(receivedResponse.edns, 0) - self.assertEquals(len(receivedResponse.options), 0) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithECS(response, receivedResponse) def testWithEDNSSameSizeInitialECS(self): """ @@ -560,12 +563,235 @@ class TestEdnsClientSubnetOverride(DNSDistTest): The initial ECS value is exactly the same size as the one it will replaced with. """ - name = 'withednsecs.overriden.ecs.tests.powerdns.com.' + name = 'withednsecs.overridden.ecs.tests.powerdns.com.' + ecso = clientsubnetoption.ClientSubnetOption('192.0.2.1', 24) + rewrittenEcso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24) + query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso]) + expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[rewrittenEcso]) + response = dns.message.make_response(query) + response.use_edns(edns=True, payload=4096, options=[rewrittenEcso]) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithECS(response, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithECS(response, receivedResponse) + +class TestECSDisabledByRuleOrLua(TestEdnsClientSubnet): + """ + dnsdist is configured to add the EDNS0 Client Subnet + option, but we disable it via DisableECSAction() + or Lua. + """ + + _config_template = """ + setECSOverride(false) + setECSSourcePrefixV4(16) + setECSSourcePrefixV6(16) + newServer{address="127.0.0.1:%s", useClientSubnet=true} + addAction(makeRule("disabled.ecsrules.tests.powerdns.com."), DisableECSAction()) + function disableECSViaLua(dq) + dq.useECS = false + return DNSAction.None, "" + end + addLuaAction("disabledvialua.ecsrules.tests.powerdns.com.", disableECSViaLua) + """ + + def testWithECSNotDisabled(self): + """ + ECS Disable: ECS enabled in the backend + """ + name = 'notdisabled.ecsrules.tests.powerdns.com.' + ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 16) + query = dns.message.make_query(name, 'A', 'IN') + expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512) + response = dns.message.make_response(expectedQuery) + expectedResponse = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.AAAA, + '::1') + response.answer.append(rrset) + expectedResponse.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseNoEDNS(expectedResponse, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseNoEDNS(expectedResponse, receivedResponse) + + def testWithECSDisabledViaRule(self): + """ + ECS Disable: ECS enabled in the backend, but disabled by a rule + """ + name = 'disabled.ecsrules.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.checkQueryNoEDNS(query, receivedQuery) + self.checkResponseNoEDNS(response, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.checkQueryNoEDNS(query, receivedQuery) + self.checkResponseNoEDNS(response, receivedResponse) + + def testWithECSDisabledViaLua(self): + """ + ECS Disable: ECS enabled in the backend, but disabled via Lua + """ + name = 'disabledvialua.ecsrules.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.checkQueryNoEDNS(query, receivedQuery) + self.checkResponseNoEDNS(response, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.checkQueryNoEDNS(query, receivedQuery) + self.checkResponseNoEDNS(response, receivedResponse) + +class TestECSOverrideSetByRuleOrLua(TestEdnsClientSubnet): + """ + dnsdist is configured to set the EDNS0 Client Subnet + option without overriding an existing one, but we + force the overriding via ECSOverrideAction() or Lua. + """ + + _config_template = """ + setECSOverride(false) + setECSSourcePrefixV4(24) + setECSSourcePrefixV6(56) + newServer{address="127.0.0.1:%s", useClientSubnet=true} + addAction(makeRule("overridden.ecsrules.tests.powerdns.com."), ECSOverrideAction(true)) + function overrideECSViaLua(dq) + dq.ecsOverride = true + return DNSAction.None, "" + end + addLuaAction("overriddenvialua.ecsrules.tests.powerdns.com.", overrideECSViaLua) + """ + + def testWithECSOverrideNotSet(self): + """ + ECS Override: not set via Lua or a rule + """ + name = 'notoverridden.ecsrules.tests.powerdns.com.' + ecso = clientsubnetoption.ClientSubnetOption('192.0.2.1', 24) + query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso]) + response = dns.message.make_response(query) + response.use_edns(edns=True, payload=4096, options=[ecso]) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.checkQueryEDNSWithECS(query, receivedQuery) + self.checkResponseEDNSWithECS(response, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.checkQueryEDNSWithECS(query, receivedQuery) + self.checkResponseEDNSWithECS(response, receivedResponse) + + def testWithECSOverrideSetViaRule(self): + """ + ECS Override: set with a rule + """ + name = 'overridden.ecsrules.tests.powerdns.com.' + ecso = clientsubnetoption.ClientSubnetOption('192.0.2.1', 24) + rewrittenEcso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24) + query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso]) + expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[rewrittenEcso]) + response = dns.message.make_response(query) + response.use_edns(edns=True, payload=4096, options=[rewrittenEcso]) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithECS(response, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithECS(response, receivedResponse) + + def testWithECSOverrideSetViaLua(self): + """ + ECS Override: set via Lua + """ + name = 'overriddenvialua.ecsrules.tests.powerdns.com.' ecso = clientsubnetoption.ClientSubnetOption('192.0.2.1', 24) rewrittenEcso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24) query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso]) expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[rewrittenEcso]) response = dns.message.make_response(query) + response.use_edns(edns=True, payload=4096, options=[rewrittenEcso]) rrset = dns.rrset.from_text(name, 3600, dns.rdataclass.IN, @@ -577,16 +803,129 @@ class TestEdnsClientSubnetOverride(DNSDistTest): self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(response, receivedResponse) - self.assertEquals(receivedResponse.edns, 0) - self.assertEquals(len(receivedResponse.options), 0) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithECS(response, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseEDNSWithECS(response, receivedResponse) + +class TestECSPrefixLengthSetByRuleOrLua(TestEdnsClientSubnet): + """ + dnsdist is configured to set the EDNS0 Client Subnet + option with a prefix length of 24 for IPv4 and 56 for IPv6, + but we override that to 32 and 128 via ECSPrefixLengthAction() or Lua. + """ + + _config_template = """ + setECSOverride(false) + setECSSourcePrefixV4(24) + setECSSourcePrefixV6(56) + newServer{address="127.0.0.1:%s", useClientSubnet=true} + addAction(makeRule("overriddenprefixlength.ecsrules.tests.powerdns.com."), ECSPrefixLengthAction(32, 128)) + function overrideECSPrefixLengthViaLua(dq) + dq.ecsPrefixLength = 32 + return DNSAction.None, "" + end + addLuaAction("overriddenprefixlengthvialua.ecsrules.tests.powerdns.com.", overrideECSPrefixLengthViaLua) + """ + + def testWithECSPrefixLengthNotOverridden(self): + """ + ECS Prefix Length: not overridden via Lua or a rule + """ + name = 'notoverriddenprefixlength.ecsrules.tests.powerdns.com.' + ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24) + query = dns.message.make_query(name, 'A', 'IN') + expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512) + response = dns.message.make_response(query) + response.use_edns(edns=True, payload=4096, options=[ecso]) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + expectedResponse = dns.message.make_response(query) + expectedResponse.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseNoEDNS(expectedResponse, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseNoEDNS(expectedResponse, receivedResponse) + + def testWithECSPrefixLengthOverriddenViaRule(self): + """ + ECS Prefix Length: overridden with a rule + """ + name = 'overriddenprefixlength.ecsrules.tests.powerdns.com.' + ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 32) + query = dns.message.make_query(name, 'A', 'IN') + expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512) + response = dns.message.make_response(expectedQuery) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + expectedResponse = dns.message.make_response(query) + expectedResponse.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseNoEDNS(expectedResponse, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseNoEDNS(expectedResponse, receivedResponse) + + def testWithECSPrefixLengthOverriddenViaLua(self): + """ + ECS Prefix Length: overridden via Lua + """ + name = 'overriddenprefixlengthvialua.ecsrules.tests.powerdns.com.' + ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 32) + query = dns.message.make_query(name, 'A', 'IN') + expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512) + response = dns.message.make_response(expectedQuery) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + expectedResponse = dns.message.make_response(query) + expectedResponse.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseNoEDNS(expectedResponse, receivedResponse) (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) self.assertTrue(receivedQuery) self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id - self.assertEquals(expectedQuery, receivedQuery) - self.assertEquals(response, receivedResponse) - self.assertEquals(receivedResponse.edns, 0) - self.assertEquals(len(receivedResponse.options), 0) + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) + self.checkResponseNoEDNS(expectedResponse, receivedResponse)