From: Remi Gacogne Date: Fri, 19 Jul 2019 14:33:10 +0000 (+0200) Subject: dnsdist: Fix ECS addition when the OPT record is not the last one X-Git-Tag: auth-4.3.0-beta2~36^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=be90d6bd464cddba82cff30c0ff27f439350e1de;p=thirdparty%2Fpdns.git dnsdist: Fix ECS addition when the OPT record is not the last one --- diff --git a/pdns/dnsdist-ecs.cc b/pdns/dnsdist-ecs.cc index 67fc573372..5b97f4d1e8 100644 --- a/pdns/dnsdist-ecs.cc +++ b/pdns/dnsdist-ecs.cc @@ -117,6 +117,7 @@ int rewriteResponseWithoutEDNS(const std::string& initialPacket, vector pr.xfrBlob(blob); pw.xfrBlob(blob); } else { + pr.skip(ah.d_clen); } } @@ -125,6 +126,192 @@ int rewriteResponseWithoutEDNS(const std::string& initialPacket, vector return 0; } +static bool addOrReplaceECSOption(std::vector>& options, bool& ecsAdded, bool overrideExisting, const string& newECSOption) +{ + for (auto it = options.begin(); it != options.end(); ) { + if (it->first == EDNSOptionCode::ECS) { + ecsAdded = false; + + if (!overrideExisting) { + return false; + } + + it = options.erase(it); + } + else { + ++it; + } + } + + options.emplace_back(EDNSOptionCode::ECS, std::string(&newECSOption.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newECSOption.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE))); + return true; +} + +static bool slowRewriteQueryWithExistingEDNS(const std::string& initialPacket, vector& newContent, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption) +{ + assert(initialPacket.size() >= sizeof(dnsheader)); + const struct dnsheader* dh = reinterpret_cast(initialPacket.data()); + + ecsAdded = false; + ednsAdded = true; + + if (ntohs(dh->qdcount) == 0) { + return false; + } + + if (ntohs(dh->arcount) == 0) { + throw std::runtime_error("slowRewriteQueryWithExistingEDNS() should not be called for queries that have no EDNS"); + } + + PacketReader pr(initialPacket); + + size_t idx = 0; + DNSName rrname; + uint16_t qdcount = ntohs(dh->qdcount); + uint16_t ancount = ntohs(dh->ancount); + uint16_t nscount = ntohs(dh->nscount); + uint16_t arcount = ntohs(dh->arcount); + uint16_t rrtype; + uint16_t rrclass; + string blob; + struct dnsrecordheader ah; + + rrname = pr.getName(); + rrtype = pr.get16BitInt(); + rrclass = pr.get16BitInt(); + + DNSPacketWriter pw(newContent, rrname, rrtype, rrclass, dh->opcode); + pw.getHeader()->id=dh->id; + pw.getHeader()->qr=dh->qr; + pw.getHeader()->aa=dh->aa; + pw.getHeader()->tc=dh->tc; + pw.getHeader()->rd=dh->rd; + pw.getHeader()->ra=dh->ra; + pw.getHeader()->ad=dh->ad; + pw.getHeader()->cd=dh->cd; + pw.getHeader()->rcode=dh->rcode; + + /* consume remaining qd if any */ + if (qdcount > 1) { + for(idx = 1; idx < qdcount; idx++) { + rrname = pr.getName(); + rrtype = pr.get16BitInt(); + rrclass = pr.get16BitInt(); + (void) rrtype; + (void) rrclass; + } + } + + /* copy AN and NS */ + for (idx = 0; idx < ancount; idx++) { + rrname = pr.getName(); + pr.getDnsrecordheader(ah); + + pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true); + pr.xfrBlob(blob); + pw.xfrBlob(blob); + } + + for (idx = 0; idx < nscount; idx++) { + rrname = pr.getName(); + pr.getDnsrecordheader(ah); + + pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true); + pr.xfrBlob(blob); + pw.xfrBlob(blob); + } + + /* consume AR, looking for OPT */ + for (idx = 0; idx < arcount; idx++) { + rrname = pr.getName(); + pr.getDnsrecordheader(ah); + + if (ah.d_type != QType::OPT) { + pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true); + pr.xfrBlob(blob); + pw.xfrBlob(blob); + } else { + + ednsAdded = false; + pr.xfrBlob(blob); + + std::vector> options; + getEDNSOptionsFromContent(blob, options); + + EDNS0Record edns0; + static_assert(sizeof(edns0) == sizeof(ah.d_ttl), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size"); + memcpy(&edns0, &ah.d_ttl, sizeof(edns0)); + + /* addOrReplaceECSOption will set it to false if there is already an existing option */ + ecsAdded = true; + addOrReplaceECSOption(options, ecsAdded, overrideExisting, newECSOption); + pw.addOpt(ah.d_class, edns0.extRCode, edns0.extFlags, options, edns0.version); + } + } + + if (ednsAdded) { + pw.addOpt(g_EdnsUDPPayloadSize, 0, 0, {{EDNSOptionCode::ECS, std::string(&newECSOption.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newECSOption.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE))}}, 0); + ecsAdded = true; + } + + pw.commit(); + + return true; +} + +static bool slowParseEDNSOptions(const char* packet, uint16_t const len, std::shared_ptr >& options) +{ + const struct dnsheader* dh = reinterpret_cast(packet); + + if (len < sizeof(dnsheader) || ntohs(dh->qdcount) == 0) + { + return false; + } + + if (ntohs(dh->arcount) == 0) { + throw std::runtime_error("slowParseEDNSOptions() should not be called for queries that have no EDNS"); + } + + try { + uint64_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount); + DNSPacketMangler dpm(const_cast(packet), len); + uint64_t n; + for(n=0; n < ntohs(dh->qdcount) ; ++n) { + dpm.skipDomainName(); + /* type and class */ + dpm.skipBytes(4); + } + + for(n=0; n < numrecords; ++n) { + dpm.skipDomainName(); + + uint8_t section = n < ntohs(dh->ancount) ? 1 : (n < (ntohs(dh->ancount) + ntohs(dh->nscount)) ? 2 : 3); + uint16_t dnstype = dpm.get16BitInt(); + dpm.get16BitInt(); + dpm.skipBytes(4); /* TTL */ + + if(section == 3 && dnstype == QType::OPT) { + uint32_t offset = dpm.getOffset(); + if (offset >= len) { + return false; + } + /* if we survive this call, we can parse it safely */ + dpm.skipRData(); + return getEDNSOptions(packet + offset, len - offset, *options) == 0; + } + else { + dpm.skipRData(); + } + } + } + catch(...) + { + return false; + } + + return true; +} + int locateEDNSOptRR(const std::string& packet, uint16_t * optStart, size_t * optLen, bool * last) { assert(optStart != NULL); @@ -315,6 +502,11 @@ bool parseEDNSOptions(DNSQuestion& dq) } dq.ednsOptions = std::make_shared >(); + + if (ntohs(dq.dh->arcount) != 0 && ntohs(dq.dh->arcount) != 1) { + return slowParseEDNSOptions(reinterpret_cast(dq.dh), dq.len, dq.ednsOptions); + } + const char* packet = reinterpret_cast(dq.dh); size_t remaining = 0; @@ -329,7 +521,7 @@ bool parseEDNSOptions(DNSQuestion& dq) return false; } -static bool addECSToExistingOPT(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, unsigned char* optRDLen, bool* const ecsAdded) +static bool addECSToExistingOPT(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, unsigned char* optRDLen, bool& ecsAdded) { /* we need to add one EDNS0 ECS option, fixing the size of EDNS0 RDLENGTH */ /* getEDNSOptionsStart has already checked that there is exactly one AR, @@ -348,12 +540,12 @@ static bool addECSToExistingOPT(char* const packet, size_t const packetSize, uin memcpy(packet + *len, newECSOption.c_str(), newECSOptionSize); *len += newECSOptionSize; - *ecsAdded = true; + ecsAdded = true; return true; } -static bool addEDNSWithECS(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, bool* const ednsAdded, bool preserveTrailingData) +static bool addEDNSWithECS(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, bool& ednsAdded, bool& ecsAdded, bool preserveTrailingData) { /* we need to add a EDNS0 RR with one EDNS0 ECS option, fixing the AR count */ string EDNSRR; @@ -378,27 +570,50 @@ static bool addEDNSWithECS(char* const packet, size_t const packetSize, uint16_t uint16_t arcount = ntohs(dh->arcount); arcount++; dh->arcount = htons(arcount); - *ednsAdded = true; + ednsAdded = true; + ecsAdded = true; memcpy(packet + realPacketLen, EDNSRR.c_str(), EDNSRR.size()); return true; } -bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool* const ednsAdded, bool* const ecsAdded, bool overrideExisting, const string& newECSOption, bool preserveTrailingData) +bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption, bool preserveTrailingData) { assert(packet != nullptr); assert(len != nullptr); assert(consumed <= (size_t) *len); - assert(ednsAdded != nullptr); - assert(ecsAdded != nullptr); + + const struct dnsheader* dh = reinterpret_cast(packet); + + if (ntohs(dh->arcount) != 0 && ntohs(dh->arcount) != 1) { + vector newContent; + newContent.reserve(packetSize); + + if (!slowRewriteQueryWithExistingEDNS(std::string(packet, *len), newContent, ednsAdded, ecsAdded, overrideExisting, newECSOption)) { + ednsAdded = false; + ecsAdded = false; + return false; + } + + if (newContent.size() > packetSize) { + ednsAdded = false; + ecsAdded = false; + return false; + } + + memcpy(packet, &newContent.at(0), newContent.size()); + *len = newContent.size(); + return true; + } + uint16_t optRDPosition = 0; size_t remaining = 0; int res = getEDNSOptionsStart(packet, consumed, *len, &optRDPosition, &remaining); if (res != 0) { - return addEDNSWithECS(packet, packetSize, len, newECSOption, ednsAdded, preserveTrailingData); + return addEDNSWithECS(packet, packetSize, len, newECSOption, ednsAdded, ecsAdded, preserveTrailingData); } unsigned char* optRDLen = reinterpret_cast(packet) + optRDPosition; @@ -422,7 +637,7 @@ bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const u return true; } -bool handleEDNSClientSubnet(DNSQuestion& dq, bool* ednsAdded, bool* ecsAdded, bool preserveTrailingData) +bool handleEDNSClientSubnet(DNSQuestion& dq, bool& ednsAdded, bool& ecsAdded, bool preserveTrailingData) { assert(dq.remote != nullptr); string newECSOption; diff --git a/pdns/dnsdist-ecs.hh b/pdns/dnsdist-ecs.hh index 767575723f..b339269064 100644 --- a/pdns/dnsdist-ecs.hh +++ b/pdns/dnsdist-ecs.hh @@ -38,8 +38,8 @@ bool isEDNSOptionInOpt(const std::string& packet, const size_t optStart, const s bool addEDNS(dnsheader* dh, uint16_t& len, const size_t size, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode); bool addEDNSToQueryTurnedResponse(DNSQuestion& dq); -bool handleEDNSClientSubnet(DNSQuestion& dq, bool* ednsAdded, bool* ecsAdded, bool preserveTrailingData); -bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool* const ednsAdded, bool* const ecsAdded, bool overrideExisting, const string& newECSOption, bool preserveTrailingData); +bool handleEDNSClientSubnet(DNSQuestion& dq, bool& ednsAdded, bool& ecsAdded, bool preserveTrailingData); +bool handleEDNSClientSubnet(char* packet, size_t packetSize, unsigned int consumed, uint16_t* len, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption, bool preserveTrailingData); bool parseEDNSOptions(DNSQuestion& dq); diff --git a/pdns/dnsdist-lua-actions.cc b/pdns/dnsdist-lua-actions.cc index 7e16700c80..85a6108d69 100644 --- a/pdns/dnsdist-lua-actions.cc +++ b/pdns/dnsdist-lua-actions.cc @@ -183,7 +183,7 @@ DNSAction::Action TeeAction::operator()(DNSQuestion* dq, std::string* ruleresult std::string newECSOption; generateECSOption(dq->ecsSet ? dq->ecs.getNetwork() : *dq->remote, newECSOption, dq->ecsSet ? dq->ecs.getBits() : dq->ecsPrefixLength); - if (!handleEDNSClientSubnet(const_cast(query.c_str()), query.capacity(), dq->qname->wirelength(), &len, &ednsAdded, &ecsAdded, dq->ecsOverride, newECSOption, g_preserveTrailingData)) { + if (!handleEDNSClientSubnet(const_cast(query.c_str()), query.capacity(), dq->qname->wirelength(), &len, ednsAdded, ecsAdded, dq->ecsOverride, newECSOption, g_preserveTrailingData)) { return DNSAction::Action::None; } diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 1c3f54a62f..a5978cae59 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -1489,7 +1489,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& } } - if (!handleEDNSClientSubnet(dq, &(dq.ednsAdded), &(dq.ecsAdded), g_preserveTrailingData)) { + if (!handleEDNSClientSubnet(dq, dq.ednsAdded, dq.ecsAdded, g_preserveTrailingData)) { vinfolog("Dropping query from %s because we couldn't insert the ECS value", dq.remote->toStringWithPort()); return ProcessQueryResult::Drop; } diff --git a/pdns/dnsparser.cc b/pdns/dnsparser.cc index 0fd335edba..f4b5e816f4 100644 --- a/pdns/dnsparser.cc +++ b/pdns/dnsparser.cc @@ -575,122 +575,6 @@ string simpleCompress(const string& elabel, const string& root) return ret; } - -/** Simple DNSPacketMangler. Ritual is: get a pointer into the packet and moveOffset() to beyond your needs - * If you survive that, feel free to read from the pointer */ -class DNSPacketMangler -{ -public: - explicit DNSPacketMangler(std::string& packet) - : d_packet((char*) packet.c_str()), d_length(packet.length()), d_notyouroffset(12), d_offset(d_notyouroffset) - {} - DNSPacketMangler(char* packet, size_t length) - : d_packet(packet), d_length(length), d_notyouroffset(12), d_offset(d_notyouroffset) - {} - - /*! Advances past a wire-format domain name - * The name is not checked for adherence to length restrictions. - * Compression pointers are not followed. - */ - void skipDomainName() - { - uint8_t len; - while((len=get8BitInt())) { - if(len >= 0xc0) { // extended label - get8BitInt(); - return; - } - skipBytes(len); - } - } - - void skipBytes(uint16_t bytes) - { - moveOffset(bytes); - } - void rewindBytes(uint16_t by) - { - rewindOffset(by); - } - uint32_t get32BitInt() - { - const char* p = d_packet + d_offset; - moveOffset(4); - uint32_t ret; - memcpy(&ret, (void*)p, sizeof(ret)); - return ntohl(ret); - } - uint16_t get16BitInt() - { - const char* p = d_packet + d_offset; - moveOffset(2); - uint16_t ret; - memcpy(&ret, (void*)p, sizeof(ret)); - return ntohs(ret); - } - - uint8_t get8BitInt() - { - const char* p = d_packet + d_offset; - moveOffset(1); - return *p; - } - - void skipRData() - { - int toskip = get16BitInt(); - moveOffset(toskip); - } - - void decreaseAndSkip32BitInt(uint32_t decrease) - { - const char *p = d_packet + d_offset; - moveOffset(4); - - uint32_t tmp; - memcpy(&tmp, (void*) p, sizeof(tmp)); - tmp = ntohl(tmp); - tmp-=decrease; - tmp = htonl(tmp); - memcpy(d_packet + d_offset-4, (const char*)&tmp, sizeof(tmp)); - } - void setAndSkip32BitInt(uint32_t value) - { - moveOffset(4); - - value = htonl(value); - memcpy(d_packet + d_offset-4, (const char*)&value, sizeof(value)); - } - uint32_t getOffset() const - { - return d_offset; - } -private: - void moveOffset(uint16_t by) - { - d_notyouroffset += by; - if(d_notyouroffset > d_length) - throw std::out_of_range("dns packet out of range: "+std::to_string(d_notyouroffset) +" > " - + std::to_string(d_length) ); - } - void rewindOffset(uint16_t by) - { - if(d_notyouroffset < by) - throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < " - + std::to_string(by)); - d_notyouroffset -= by; - if(d_notyouroffset < 12) - throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < " - + std::to_string(12)); - } - char* d_packet; - size_t d_length; - - uint32_t d_notyouroffset; // only 'moveOffset' can touch this - const uint32_t& d_offset; // look.. but don't touch - -}; - // method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it void editDNSPacketTTL(char* packet, size_t length, std::function visitor) { diff --git a/pdns/dnsparser.hh b/pdns/dnsparser.hh index 33228bdac2..3bb22f3b7e 100644 --- a/pdns/dnsparser.hh +++ b/pdns/dnsparser.hh @@ -408,4 +408,123 @@ std::shared_ptr getRR(const DNSRecord& dr) return std::dynamic_pointer_cast(dr.d_content); } +/** Simple DNSPacketMangler. Ritual is: get a pointer into the packet and moveOffset() to beyond your needs + * If you survive that, feel free to read from the pointer */ +class DNSPacketMangler +{ +public: + explicit DNSPacketMangler(std::string& packet) + : d_packet((char*) packet.c_str()), d_length(packet.length()), d_notyouroffset(12), d_offset(d_notyouroffset) + {} + DNSPacketMangler(char* packet, size_t length) + : d_packet(packet), d_length(length), d_notyouroffset(12), d_offset(d_notyouroffset) + {} + + /*! Advances past a wire-format domain name + * The name is not checked for adherence to length restrictions. + * Compression pointers are not followed. + */ + void skipDomainName() + { + uint8_t len; + while((len=get8BitInt())) { + if(len >= 0xc0) { // extended label + get8BitInt(); + return; + } + skipBytes(len); + } + } + + void skipBytes(uint16_t bytes) + { + moveOffset(bytes); + } + void rewindBytes(uint16_t by) + { + rewindOffset(by); + } + uint32_t get32BitInt() + { + const char* p = d_packet + d_offset; + moveOffset(4); + uint32_t ret; + memcpy(&ret, (void*)p, sizeof(ret)); + return ntohl(ret); + } + uint16_t get16BitInt() + { + const char* p = d_packet + d_offset; + moveOffset(2); + uint16_t ret; + memcpy(&ret, (void*)p, sizeof(ret)); + return ntohs(ret); + } + + uint8_t get8BitInt() + { + const char* p = d_packet + d_offset; + moveOffset(1); + return *p; + } + + void skipRData() + { + int toskip = get16BitInt(); + moveOffset(toskip); + } + + void decreaseAndSkip32BitInt(uint32_t decrease) + { + const char *p = d_packet + d_offset; + moveOffset(4); + + uint32_t tmp; + memcpy(&tmp, (void*) p, sizeof(tmp)); + tmp = ntohl(tmp); + tmp-=decrease; + tmp = htonl(tmp); + memcpy(d_packet + d_offset-4, (const char*)&tmp, sizeof(tmp)); + } + + void setAndSkip32BitInt(uint32_t value) + { + moveOffset(4); + + value = htonl(value); + memcpy(d_packet + d_offset-4, (const char*)&value, sizeof(value)); + } + + uint32_t getOffset() const + { + return d_offset; + } + +private: + void moveOffset(uint16_t by) + { + d_notyouroffset += by; + if(d_notyouroffset > d_length) + throw std::out_of_range("dns packet out of range: "+std::to_string(d_notyouroffset) +" > " + + std::to_string(d_length) ); + } + + void rewindOffset(uint16_t by) + { + if(d_notyouroffset < by) + throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < " + + std::to_string(by)); + d_notyouroffset -= by; + if(d_notyouroffset < 12) + throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < " + + std::to_string(12)); + } + + char* d_packet; + size_t d_length; + + uint32_t d_notyouroffset; // only 'moveOffset' can touch this + const uint32_t& d_offset; // look.. but don't touch +}; + #endif diff --git a/pdns/ednsoptions.cc b/pdns/ednsoptions.cc index 6b4ec1098f..d20755be73 100644 --- a/pdns/ednsoptions.cc +++ b/pdns/ednsoptions.cc @@ -105,6 +105,29 @@ int getEDNSOptions(const char* optRR, const size_t len, EDNSOptionViewMap& optio return 0; } +bool getEDNSOptionsFromContent(const std::string& content, std::vector>& options) +{ + size_t pos = 0; + uint16_t code, len; + const size_t contentLength = content.size(); + + while (pos < contentLength && (contentLength - pos) >= (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)) { + code = (static_cast(content.at(pos)) * 256) + static_cast(content.at(pos+1)); + pos += EDNS_OPTION_CODE_SIZE; + len = (static_cast(content.at(pos)) * 256) + static_cast(content.at(pos+1)); + pos += EDNS_OPTION_LENGTH_SIZE; + + if (pos > contentLength || len > (contentLength - pos)) { + return false; + } + + options.emplace_back(code, std::string(&content.at(pos), len)); + pos += len; + } + + return true; +} + void generateEDNSOption(uint16_t optionCode, const std::string& payload, std::string& res) { const uint16_t ednsOptionCode = htons(optionCode); diff --git a/pdns/ednsoptions.hh b/pdns/ednsoptions.hh index 019ac9bb9b..4c8a330ceb 100644 --- a/pdns/ednsoptions.hh +++ b/pdns/ednsoptions.hh @@ -47,6 +47,8 @@ typedef std::map EDNSOptionViewMap; /* extract all EDNS0 options from a pointer on the beginning rdLen of the OPT RR */ int getEDNSOptions(const char* optRR, size_t len, EDNSOptionViewMap& options); +/* extract all EDNS0 options from the content (so after rdLen) of the OPT RR */ +bool getEDNSOptionsFromContent(const std::string& content, std::vector>& options); void generateEDNSOption(uint16_t optionCode, const std::string& payload, std::string& res); diff --git a/pdns/test-dnsdist_cc.cc b/pdns/test-dnsdist_cc.cc index 14d463db31..dfa035c4de 100644 --- a/pdns/test-dnsdist_cc.cc +++ b/pdns/test-dnsdist_cc.cc @@ -42,7 +42,7 @@ BOOST_AUTO_TEST_SUITE(test_dnsdist_cc) static const uint16_t ECSSourcePrefixV4 = 24; static const uint16_t ECSSourcePrefixV6 = 56; -static void validateQuery(const char * packet, size_t packetSize, bool hasEdns=true, bool hasXPF=false) +static void validateQuery(const char * packet, size_t packetSize, bool hasEdns=true, bool hasXPF=false, uint16_t additionals=0) { MOADNSParser mdp(true, packet, packetSize); @@ -51,7 +51,7 @@ static void validateQuery(const char * packet, size_t packetSize, bool hasEdns=t BOOST_CHECK_EQUAL(mdp.d_header.qdcount, 1U); BOOST_CHECK_EQUAL(mdp.d_header.ancount, 0U); BOOST_CHECK_EQUAL(mdp.d_header.nscount, 0U); - uint16_t expectedARCount = 0 + (hasEdns ? 1 : 0) + (hasXPF ? 1 : 0); + uint16_t expectedARCount = additionals + (hasEdns ? 1U : 0U) + (hasXPF ? 1U : 0U); BOOST_CHECK_EQUAL(mdp.d_header.arcount, expectedARCount); } @@ -231,10 +231,10 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNS) BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, false)); + BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, false, newECSOption, false)); BOOST_CHECK(static_cast(len) > query.size()); BOOST_CHECK_EQUAL(ednsAdded, true); - BOOST_CHECK_EQUAL(ecsAdded, false); + BOOST_CHECK_EQUAL(ecsAdded, true); validateQuery(packet, len); validateECS(packet, len, remote); vector queryWithEDNS; @@ -250,7 +250,7 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNS) BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, false)); + BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, false, newECSOption, false)); BOOST_CHECK_EQUAL(static_cast(len), query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); @@ -273,11 +273,11 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNS) packet[len + idx] = 'A'; } len += trailingDataSize; - BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, false)); + BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, false, newECSOption, false)); BOOST_REQUIRE_EQUAL(static_cast(len), queryWithEDNS.size()); BOOST_CHECK_EQUAL(memcmp(queryWithEDNS.data(), packet, queryWithEDNS.size()), 0); BOOST_CHECK_EQUAL(ednsAdded, true); - BOOST_CHECK_EQUAL(ecsAdded, false); + BOOST_CHECK_EQUAL(ecsAdded, true); validateQuery(packet, len); /* packet with trailing data (preserving trailing data) */ @@ -296,14 +296,14 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNS) packet[len + idx] = 'A'; } len += trailingDataSize; - BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, true)); + BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, false, newECSOption, true)); BOOST_REQUIRE_EQUAL(static_cast(len), queryWithEDNS.size() + trailingDataSize); BOOST_CHECK_EQUAL(memcmp(queryWithEDNS.data(), packet, queryWithEDNS.size()), 0); for (size_t idx = 0; idx < trailingDataSize; idx++) { BOOST_CHECK_EQUAL(packet[queryWithEDNS.size() + idx], 'A'); } BOOST_CHECK_EQUAL(ednsAdded, true); - BOOST_CHECK_EQUAL(ecsAdded, false); + BOOST_CHECK_EQUAL(ecsAdded, true); validateQuery(packet, len); } @@ -335,10 +335,10 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed) BOOST_CHECK(!parseEDNSOptions(dq)); /* And now we add our own ECS */ - BOOST_CHECK(handleEDNSClientSubnet(dq, &ednsAdded, &ecsAdded, false)); + BOOST_CHECK(handleEDNSClientSubnet(dq, ednsAdded, ecsAdded, false)); BOOST_CHECK_GT(static_cast(dq.len), query.size()); BOOST_CHECK_EQUAL(ednsAdded, true); - BOOST_CHECK_EQUAL(ecsAdded, false); + BOOST_CHECK_EQUAL(ecsAdded, true); validateQuery(packet, dq.len); validateECS(packet, dq.len, remote); @@ -352,7 +352,7 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed) BOOST_CHECK(qclass == QClass::IN); DNSQuestion dq2(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast(query.data()), query.size(), query.size(), false, nullptr); - BOOST_CHECK(!handleEDNSClientSubnet(dq2, &ednsAdded, &ecsAdded, false)); + BOOST_CHECK(!handleEDNSClientSubnet(dq2, ednsAdded, ecsAdded, false)); BOOST_CHECK_EQUAL(static_cast(dq2.len), query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); @@ -384,7 +384,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) { BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, false)); + BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, false, newECSOption, false)); BOOST_CHECK((size_t) len > query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, true); @@ -400,7 +400,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) { BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, false)); + BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, false, newECSOption, false)); BOOST_CHECK_EQUAL((size_t) len, query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); @@ -436,7 +436,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECSAlreadyParsed) { BOOST_CHECK(parseEDNSOptions(dq)); /* And now we add our own ECS */ - BOOST_CHECK(handleEDNSClientSubnet(dq, &ednsAdded, &ecsAdded, false)); + BOOST_CHECK(handleEDNSClientSubnet(dq, ednsAdded, ecsAdded, false)); BOOST_CHECK_GT(static_cast(dq.len), query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, true); @@ -453,7 +453,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECSAlreadyParsed) { BOOST_CHECK(qclass == QClass::IN); DNSQuestion dq2(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast(query.data()), query.size(), query.size(), false, nullptr); - BOOST_CHECK(!handleEDNSClientSubnet(dq2, &ednsAdded, &ecsAdded, false)); + BOOST_CHECK(!handleEDNSClientSubnet(dq2, ednsAdded, ecsAdded, false)); BOOST_CHECK_EQUAL(static_cast(dq2.len), query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); @@ -491,7 +491,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSameSize) { BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption, false)); + BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); BOOST_CHECK_EQUAL((size_t) len, query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); @@ -536,7 +536,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSameSizeAlreadyParsed) { BOOST_CHECK(parseEDNSOptions(dq)); /* And now we add our own ECS */ - BOOST_CHECK(handleEDNSClientSubnet(dq, &ednsAdded, &ecsAdded, false)); + BOOST_CHECK(handleEDNSClientSubnet(dq, ednsAdded, ecsAdded, false)); BOOST_CHECK_EQUAL(static_cast(dq.len), query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); @@ -575,7 +575,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSmaller) { BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption, false)); + BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); BOOST_CHECK((size_t) len < query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); @@ -614,7 +614,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithLarger) { BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption, false)); + BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); BOOST_CHECK((size_t) len > query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); @@ -630,13 +630,227 @@ BOOST_AUTO_TEST_CASE(replaceECSWithLarger) { BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption, false)); + BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); BOOST_CHECK_EQUAL((size_t) len, query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); validateQuery(reinterpret_cast(query.data()), len); } +BOOST_AUTO_TEST_CASE(replaceECSFollowedByTSIG) { + bool ednsAdded = false; + bool ecsAdded = false; + ComboAddress remote("192.168.1.25"); + DNSName name("www.powerdns.com."); + ComboAddress origRemote("127.0.0.1"); + string newECSOption; + generateECSOption(remote, newECSOption, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); + + vector query; + DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0); + pw.getHeader()->rd = 1; + EDNSSubnetOpts ecsOpts; + ecsOpts.source = Netmask(origRemote, 8); + string origECSOption = makeEDNSSubnetOptsString(ecsOpts); + DNSPacketWriter::optvect_t opts; + opts.push_back(make_pair(EDNSOptionCode::ECS, origECSOption)); + pw.addOpt(512, 0, 0, opts); + pw.startRecord(DNSName("tsigname."), QType::TSIG, 0, QClass::ANY, DNSResourceRecord::ADDITIONAL, false); + pw.commit(); + uint16_t len = query.size(); + + /* large enough packet */ + char packet[1500]; + memcpy(packet, query.data(), query.size()); + + unsigned int consumed = 0; + uint16_t qtype; + DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + BOOST_CHECK_EQUAL(qname, name); + BOOST_CHECK(qtype == QType::A); + + BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); + BOOST_CHECK((size_t) len > query.size()); + BOOST_CHECK_EQUAL(ednsAdded, false); + BOOST_CHECK_EQUAL(ecsAdded, false); + validateQuery(packet, len, true, false, 1); + validateECS(packet, len, remote); + + /* not large enough packet */ + ednsAdded = false; + ecsAdded = false; + consumed = 0; + len = query.size(); + qname = DNSName(reinterpret_cast(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + BOOST_CHECK_EQUAL(qname, name); + BOOST_CHECK(qtype == QType::A); + + BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); + BOOST_CHECK_EQUAL((size_t) len, query.size()); + BOOST_CHECK_EQUAL(ednsAdded, false); + BOOST_CHECK_EQUAL(ecsAdded, false); + validateQuery(reinterpret_cast(query.data()), len, true, false, 1); +} + +BOOST_AUTO_TEST_CASE(replaceECSBetweenTwoRecords) { + bool ednsAdded = false; + bool ecsAdded = false; + ComboAddress remote("192.168.1.25"); + DNSName name("www.powerdns.com."); + ComboAddress origRemote("127.0.0.1"); + string newECSOption; + generateECSOption(remote, newECSOption, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); + + vector query; + DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0); + pw.getHeader()->rd = 1; + EDNSSubnetOpts ecsOpts; + ecsOpts.source = Netmask(origRemote, 8); + string origECSOption = makeEDNSSubnetOptsString(ecsOpts); + DNSPacketWriter::optvect_t opts; + opts.push_back(make_pair(EDNSOptionCode::ECS, origECSOption)); + pw.startRecord(DNSName("additional"), QType::A, 0, QClass::IN, DNSResourceRecord::ADDITIONAL, false); + pw.xfr32BitInt(0x01020304); + pw.addOpt(512, 0, 0, opts); + pw.startRecord(DNSName("tsigname."), QType::TSIG, 0, QClass::ANY, DNSResourceRecord::ADDITIONAL, false); + pw.commit(); + uint16_t len = query.size(); + + /* large enough packet */ + char packet[1500]; + memcpy(packet, query.data(), query.size()); + + unsigned int consumed = 0; + uint16_t qtype; + DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + BOOST_CHECK_EQUAL(qname, name); + BOOST_CHECK(qtype == QType::A); + + BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); + BOOST_CHECK((size_t) len > query.size()); + BOOST_CHECK_EQUAL(ednsAdded, false); + BOOST_CHECK_EQUAL(ecsAdded, false); + validateQuery(packet, len, true, false, 2); + validateECS(packet, len, remote); + + /* not large enough packet */ + ednsAdded = false; + ecsAdded = false; + consumed = 0; + len = query.size(); + qname = DNSName(reinterpret_cast(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + BOOST_CHECK_EQUAL(qname, name); + BOOST_CHECK(qtype == QType::A); + + BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); + BOOST_CHECK_EQUAL((size_t) len, query.size()); + BOOST_CHECK_EQUAL(ednsAdded, false); + BOOST_CHECK_EQUAL(ecsAdded, false); + validateQuery(reinterpret_cast(query.data()), len, true, false, 2); +} + +BOOST_AUTO_TEST_CASE(insertECSInEDNSBetweenTwoRecords) { + bool ednsAdded = false; + bool ecsAdded = false; + ComboAddress remote("192.168.1.25"); + DNSName name("www.powerdns.com."); + ComboAddress origRemote("127.0.0.1"); + string newECSOption; + generateECSOption(remote, newECSOption, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); + + vector query; + DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0); + pw.getHeader()->rd = 1; + pw.startRecord(DNSName("additional"), QType::A, 0, QClass::IN, DNSResourceRecord::ADDITIONAL, false); + pw.xfr32BitInt(0x01020304); + pw.addOpt(512, 0, 0); + pw.startRecord(DNSName("tsigname."), QType::TSIG, 0, QClass::ANY, DNSResourceRecord::ADDITIONAL, false); + pw.commit(); + uint16_t len = query.size(); + + /* large enough packet */ + char packet[1500]; + memcpy(packet, query.data(), query.size()); + + unsigned int consumed = 0; + uint16_t qtype; + DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + BOOST_CHECK_EQUAL(qname, name); + BOOST_CHECK(qtype == QType::A); + + BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); + BOOST_CHECK((size_t) len > query.size()); + BOOST_CHECK_EQUAL(ednsAdded, false); + BOOST_CHECK_EQUAL(ecsAdded, true); + validateQuery(packet, len, true, false, 2); + validateECS(packet, len, remote); + + /* not large enough packet */ + ednsAdded = false; + ecsAdded = false; + consumed = 0; + len = query.size(); + qname = DNSName(reinterpret_cast(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + BOOST_CHECK_EQUAL(qname, name); + BOOST_CHECK(qtype == QType::A); + + BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); + BOOST_CHECK_EQUAL((size_t) len, query.size()); + BOOST_CHECK_EQUAL(ednsAdded, false); + BOOST_CHECK_EQUAL(ecsAdded, false); + validateQuery(reinterpret_cast(query.data()), len, true, false, 2); +} + +BOOST_AUTO_TEST_CASE(insertECSAfterTSIG) { + bool ednsAdded = false; + bool ecsAdded = false; + ComboAddress remote("192.168.1.25"); + DNSName name("www.powerdns.com."); + ComboAddress origRemote("127.0.0.1"); + string newECSOption; + generateECSOption(remote, newECSOption, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6); + + vector query; + DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0); + pw.getHeader()->rd = 1; + pw.startRecord(DNSName("tsigname."), QType::TSIG, 0, QClass::ANY, DNSResourceRecord::ADDITIONAL, false); + pw.commit(); + uint16_t len = query.size(); + + /* large enough packet */ + char packet[1500]; + memcpy(packet, query.data(), query.size()); + + unsigned int consumed = 0; + uint16_t qtype; + DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + BOOST_CHECK_EQUAL(qname, name); + BOOST_CHECK(qtype == QType::A); + + BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); + BOOST_CHECK((size_t) len > query.size()); + BOOST_CHECK_EQUAL(ednsAdded, true); + BOOST_CHECK_EQUAL(ecsAdded, true); + /* the MOADNSParser does not allow anything except XPF after a TSIG */ + BOOST_CHECK_THROW(validateQuery(packet, len, true, false, 1), MOADNSException); + validateECS(packet, len, remote); + + /* not large enough packet */ + ednsAdded = false; + ecsAdded = false; + consumed = 0; + len = query.size(); + qname = DNSName(reinterpret_cast(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + BOOST_CHECK_EQUAL(qname, name); + BOOST_CHECK(qtype == QType::A); + + BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); + BOOST_CHECK_EQUAL((size_t) len, query.size()); + BOOST_CHECK_EQUAL(ednsAdded, false); + BOOST_CHECK_EQUAL(ecsAdded, false); + validateQuery(reinterpret_cast(query.data()), len, true, false); +} + BOOST_AUTO_TEST_CASE(removeEDNSWhenFirst) { DNSName name("www.powerdns.com."); diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index f5e0b92ee6..d7c1b3b363 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -544,6 +544,9 @@ class DNSDistTest(unittest.TestCase): if withCookies: for option in received.options: self.assertEquals(option.otype, 10) + else: + for option in received.options: + self.assertNotEquals(option.otype, 10) def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0): self.assertEquals(expected, received) diff --git a/regression-tests.dnsdist/test_EdnsClientSubnet.py b/regression-tests.dnsdist/test_EdnsClientSubnet.py index 241783fba2..b01e6bc38b 100644 --- a/regression-tests.dnsdist/test_EdnsClientSubnet.py +++ b/regression-tests.dnsdist/test_EdnsClientSubnet.py @@ -204,6 +204,7 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest): ecsoResponse = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24, scope=24) response.use_edns(edns=True, payload=4096, options=[ecoResponse, ecsoResponse]) expectedResponse = dns.message.make_response(query) + expectedResponse.use_edns(edns=True, payload=4096, options=[ecoResponse]) rrset = dns.rrset.from_text(name, 3600, dns.rdataclass.IN, @@ -242,6 +243,7 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest): ecsoResponse = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24, scope=24) response.use_edns(edns=True, payload=4096, options=[ecsoResponse, ecoResponse]) expectedResponse = dns.message.make_response(query, our_payload=4096) + expectedResponse.use_edns(edns=True, payload=4096, options=[ecoResponse]) rrset = dns.rrset.from_text(name, 3600, dns.rdataclass.IN, @@ -280,6 +282,7 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest): ecsoResponse = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24, scope=24) response.use_edns(edns=True, payload=4096, options=[ecoResponse, ecsoResponse, ecoResponse]) expectedResponse = dns.message.make_response(query, our_payload=4096) + expectedResponse.use_edns(edns=True, payload=4096, options=[ecoResponse, ecoResponse]) rrset = dns.rrset.from_text(name, 3600, dns.rdataclass.IN, @@ -482,6 +485,97 @@ class TestEdnsClientSubnetOverride(DNSDistTest): self.checkQueryEDNSWithECS(expectedQuery, receivedQuery) self.checkResponseEDNSWithECS(response, receivedResponse) + def testWithECSFollowedByAnother(self): + """ + ECS: Existing EDNS with ECS, followed by another record + + Send a query with EDNS and an existing ECS value. + The OPT record is not the last one in the query + and is followed by another one. + Check that the query received by the responder + has a valid ECS value and that the response + received from dnsdist contains an EDNS pseudo-RR. + """ + name = 'withecs-followedbyanother.ecs.tests.powerdns.com.' + ecso = clientsubnetoption.ClientSubnetOption('192.0.2.1', 24) + eco = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef') + rewrittenEcso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + + query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco,ecso,eco]) + # I would have loved to use a TSIG here but I can't find how to make dnspython ignore + # it while parsing the message in the receiver :-/ + query.additional.append(rrset) + expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco,eco,rewrittenEcso]) + expectedQuery.additional.append(rrset) + + response = dns.message.make_response(expectedQuery) + response.use_edns(edns=True, payload=4096, options=[eco, ecso, eco]) + expectedResponse = dns.message.make_response(query) + expectedResponse.use_edns(edns=True, payload=4096, options=[eco, ecso, eco]) + response.answer.append(rrset) + response.additional.append(rrset) + expectedResponse.answer.append(rrset) + expectedResponse.additional.append(rrset) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (receivedQuery, receivedResponse) = sender(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery, 2) + self.checkResponseEDNSWithECS(expectedResponse, receivedResponse, 2) + + def testWithEDNSNoECSFollowedByAnother(self): + """ + ECS: Existing EDNS without ECS, followed by another record + + Send a query with EDNS but no ECS value. + The OPT record is not the last one in the query + and is followed by another one. + Check that the query received by the responder + has a valid ECS value and that the response + received from dnsdist contains an EDNS pseudo-RR. + """ + name = 'withedns-no-ecs-followedbyanother.ecs.tests.powerdns.com.' + eco = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef') + rewrittenEcso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + + query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco]) + # I would have loved to use a TSIG here but I can't find how to make dnspython ignore + # it while parsing the message in the receiver :-/ + query.additional.append(rrset) + expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco,rewrittenEcso]) + expectedQuery.additional.append(rrset) + + response = dns.message.make_response(expectedQuery) + response.use_edns(edns=True, payload=4096, options=[eco, rewrittenEcso, eco]) + expectedResponse = dns.message.make_response(query) + expectedResponse.use_edns(edns=True, payload=4096, options=[eco, eco]) + response.answer.append(rrset) + response.additional.append(rrset) + expectedResponse.answer.append(rrset) + expectedResponse.additional.append(rrset) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (receivedQuery, receivedResponse) = sender(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.checkQueryEDNSWithECS(expectedQuery, receivedQuery, 1) + self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, 2) + class TestECSDisabledByRuleOrLua(DNSDistTest): """ dnsdist is configured to add the EDNS0 Client Subnet