From bb09e03d2803c111a7fe1f474885652113e3c959 Mon Sep 17 00:00:00 2001 From: Charles-Henri Bruyand Date: Fri, 10 Dec 2021 16:59:27 +0100 Subject: [PATCH] dnsdist: add non ffi interface to clear given record types in a response --- pdns/dnsdist-lua-actions.cc | 39 +++++ pdns/dnsdistdist/docs/rules-actions.rst | 8 + pdns/dnsparser.cc | 171 ++++++++++++++------- pdns/test-dnsparser_cc.cc | 8 - regression-tests.dnsdist/test_Responses.py | 31 +++- 5 files changed, 194 insertions(+), 63 deletions(-) diff --git a/pdns/dnsdist-lua-actions.cc b/pdns/dnsdist-lua-actions.cc index a75ba39300..0598d222a4 100644 --- a/pdns/dnsdist-lua-actions.cc +++ b/pdns/dnsdist-lua-actions.cc @@ -1720,6 +1720,32 @@ private: std::string d_value; }; +class ClearRecordTypesResponseAction : public DNSResponseAction, public boost::noncopyable +{ +public: + ClearRecordTypesResponseAction() {} + + ClearRecordTypesResponseAction(std::set qtypes) : d_qtypes(qtypes) + { + } + + DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override + { + if (d_qtypes.size() > 0) { + clearDNSPacketRecordTypes(dr->getMutableData(), d_qtypes); + } + return DNSResponseAction::Action::HeaderModify; + } + + std::string toString() const override + { + return "clear record types"; + } + +private: + std::set d_qtypes{}; +}; + class ContinueAction : public DNSAction { public: @@ -2234,6 +2260,19 @@ void setupLuaActions(LuaContext& luaCtx) return std::shared_ptr(new LimitTTLResponseAction(0, max)); }); + luaCtx.writeFunction("ClearRecordTypesResponseAction", [](boost::variant>> types) { + std::set qtypes{}; + if(auto t = boost::get(types)) { + qtypes.insert(t); + } else { + const auto& v = boost::get>>(types); + for(const auto& tpair: v) { + qtypes.insert(tpair.second); + } + } + return std::shared_ptr(new ClearRecordTypesResponseAction(qtypes)); + }); + luaCtx.writeFunction("RCodeAction", [](uint8_t rcode, boost::optional vars) { auto ret = std::shared_ptr(new RCodeAction(rcode)); auto rca = std::dynamic_pointer_cast(ret); diff --git a/pdns/dnsdistdist/docs/rules-actions.rst b/pdns/dnsdistdist/docs/rules-actions.rst index 997d785a6b..582a2505ca 100644 --- a/pdns/dnsdistdist/docs/rules-actions.rst +++ b/pdns/dnsdistdist/docs/rules-actions.rst @@ -838,6 +838,14 @@ The following actions exist. Let these packets go through. +.. function::ClearRecordTypesResponseAction(types) + + .. versionadded:: 1.8.0 + + Removes given type(s) records from the response. + + :param int types: a single type or a list of types to remove + .. function:: ContinueAction(action) .. versionadded:: 1.4.0 diff --git a/pdns/dnsparser.cc b/pdns/dnsparser.cc index a2a6634dbf..b592205996 100644 --- a/pdns/dnsparser.cc +++ b/pdns/dnsparser.cc @@ -745,69 +745,132 @@ void editDNSPacketTTL(char* packet, size_t length, const std::function& packet, const std::set& qtypes) -{ - size_t finalsize = packet.size(); - clearDNSPacketRecordTypes(reinterpret_cast(packet.data()), finalsize, qtypes); - packet.resize(finalsize); -} - -void clearDNSPacketRecordTypes(PacketBuffer& packet, const std::set& qtypes) +int rewritePacketWithoutRecordTypes(const PacketBuffer& initialPacket, PacketBuffer& newContent, const std::set& qtypes) { - size_t finalsize = packet.size(); - clearDNSPacketRecordTypes(reinterpret_cast(packet.data()), finalsize, qtypes); - packet.resize(finalsize); -} + static const std::set& safeTypes{QType::A, QType::AAAA, QType::DHCID, QType::TXT, QType::LUA, QType::ENT, QType::OPT, QType::HINFO, QType::DNSKEY, QType::CDNSKEY, QType::DS, QType::CDS, QType::DLV, QType::SSHFP, QType::KEY, QType::CERT, QType::TLSA, QType::SMIMEA, QType::OPENPGPKEY, QType::SVCB, QType::HTTPS, QType::NSEC3, QType::CSYNC, QType::NSEC3PARAM, QType::LOC, QType::NID, QType::L32, QType::L64, QType::EUI48, QType::EUI64, QType::URI, QType::CAA}; -// method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it -void clearDNSPacketRecordTypes(char* packet, size_t& length, const std::set& qtypes) -{ - if (length < sizeof(dnsheader)) { - return; + if (initialPacket.size() < sizeof(dnsheader)) { + return EINVAL; } - try - { - dnsheader* dh = reinterpret_cast(packet); - uint64_t ancount = ntohs(dh->ancount); - uint64_t nscount = ntohs(dh->nscount); - uint64_t arcount = ntohs(dh->arcount); - DNSPacketMangler dpm(packet, length); + try { + const struct dnsheader* dh = reinterpret_cast(initialPacket.data()); + + if (ntohs(dh->qdcount) == 0) + return ENOENT; + auto packetView = pdns_string_view(reinterpret_cast(initialPacket.data()), initialPacket.size()); + + PacketReader pr(packetView); + + size_t idx = 0; + DNSName rrname; + uint16_t qdcount = ntohs(dh->qdcount); + uint16_t ancount = ntohs(dh->ancount); + uint16_t nscount = ntohs(dh->nscount); + uint16_t arcount = ntohs(dh->arcount); + uint16_t rrtype; + uint16_t rrclass; + string blob; + struct dnsrecordheader ah; - for(uint64_t n = 0; n < ntohs(dh->qdcount) ; ++n) { - dpm.skipDomainName(); - /* type and class */ - dpm.skipBytes(4); + rrname = pr.getName(); + rrtype = pr.get16BitInt(); + rrclass = pr.get16BitInt(); + + GenericDNSPacketWriter pw(newContent, rrname, rrtype, rrclass, dh->opcode); + pw.getHeader()->id=dh->id; + pw.getHeader()->qr=dh->qr; + pw.getHeader()->aa=dh->aa; + pw.getHeader()->tc=dh->tc; + pw.getHeader()->rd=dh->rd; + pw.getHeader()->ra=dh->ra; + pw.getHeader()->ad=dh->ad; + pw.getHeader()->cd=dh->cd; + pw.getHeader()->rcode=dh->rcode; + + /* consume remaining qd if any */ + if (qdcount > 1) { + for(idx = 1; idx < qdcount; idx++) { + rrname = pr.getName(); + rrtype = pr.get16BitInt(); + rrclass = pr.get16BitInt(); + (void) rrtype; + (void) rrclass; + } + } + + /* copy AN */ + for (idx = 0; idx < ancount; idx++) { + rrname = pr.getName(); + pr.getDnsrecordheader(ah); + pr.xfrBlob(blob); + + if (qtypes.find(ah.d_type) == qtypes.end()) { + // if this is not a safe type + if (safeTypes.find(ah.d_type) == safeTypes.end()) { + // "unsafe" types might countain compressed data, so cancel rewrite + newContent.clear(); + return EIO; + } + pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true); + pw.xfrBlob(blob); + } + } + + /* copy NS */ + for (idx = 0; idx < nscount; idx++) { + rrname = pr.getName(); + pr.getDnsrecordheader(ah); + pr.xfrBlob(blob); + + if (qtypes.find(ah.d_type) == qtypes.end()) { + if (safeTypes.find(ah.d_type) == safeTypes.end()) { + // "unsafe" types might countain compressed data, so cancel rewrite + newContent.clear(); + return EIO; + } + pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true); + pw.xfrBlob(blob); + } } - auto processSection = [&dpm, &qtypes, &length](uint64_t& count) { - for (uint64_t n=0; n < count; ++n) { - uint32_t start = dpm.getOffset(); - dpm.skipDomainName(); - - const auto dnstype = QType{dpm.get16BitInt()}; - /* class, ttl */ - dpm.skipBytes(6); - dpm.skipRData(); - if (qtypes.find(dnstype) != qtypes.end()) { - uint32_t rlen = dpm.getOffset() - start; - dpm.rewindBytes(rlen); - dpm.shrinkBytes(rlen); - // update count - count -= 1; - length -= rlen; - n -= 1; + /* copy AR */ + for (idx = 0; idx < arcount; idx++) { + rrname = pr.getName(); + pr.getDnsrecordheader(ah); + pr.xfrBlob(blob); + + if (qtypes.find(ah.d_type) == qtypes.end()) { + if (safeTypes.find(ah.d_type) == safeTypes.end()) { + // "unsafe" types might countain compressed data, so cancel rewrite + newContent.clear(); + return EIO; } + pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true); + pw.xfrBlob(blob); } - }; - processSection(ancount); - dh->ancount = htons(ancount); - processSection(nscount); - dh->nscount = htons(nscount); - processSection(arcount); - dh->arcount = htons(arcount); + } + pw.commit(); + } - catch(...) + catch (...) { - return; + newContent.clear(); + return EIO; + } + return 0; +} + +void clearDNSPacketRecordTypes(vector& packet, const std::set& qtypes) +{ + return clearDNSPacketRecordTypes(reinterpret_cast(packet), qtypes); +} + +void clearDNSPacketRecordTypes(PacketBuffer& packet, const std::set& qtypes) +{ + PacketBuffer newContent; + + auto result = rewritePacketWithoutRecordTypes(packet, newContent, qtypes); + if (!result) { + packet = std::move(newContent); } } diff --git a/pdns/test-dnsparser_cc.cc b/pdns/test-dnsparser_cc.cc index f24bb26b59..2ec3802a5e 100644 --- a/pdns/test-dnsparser_cc.cc +++ b/pdns/test-dnsparser_cc.cc @@ -504,9 +504,6 @@ BOOST_AUTO_TEST_CASE(test_clearDNSPacketRecordTypes) { pwR.xfrIP6(std::string(reinterpret_cast(v6.sin6.sin6_addr.s6_addr), 16)); pwR.commit(); - pwR.startRecord(name, QType::SOA, 257, QClass::IN, DNSResourceRecord::AUTHORITY); - pwR.commit(); - pwR.startRecord(name, QType::A, 256, QClass::IN, DNSResourceRecord::ADDITIONAL); pwR.xfrIP(v4.sin4.sin_addr.s_addr); pwR.commit(); @@ -520,7 +517,6 @@ BOOST_AUTO_TEST_CASE(test_clearDNSPacketRecordTypes) { BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 1, QType::A), 1); BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 1, QType::AAAA), 1); - BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 2, QType::SOA), 1); BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 3, QType::A), 1); std::set toremove{QType::AAAA}; @@ -528,7 +524,6 @@ BOOST_AUTO_TEST_CASE(test_clearDNSPacketRecordTypes) { BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 1, QType::A), 1); BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 1, QType::AAAA), 0); - BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 2, QType::SOA), 1); BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 3, QType::A), 1); toremove = {QType::A}; @@ -536,14 +531,12 @@ BOOST_AUTO_TEST_CASE(test_clearDNSPacketRecordTypes) { BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 1, QType::A), 0); BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 1, QType::AAAA), 0); - BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 2, QType::SOA), 1); BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 3, QType::A), 0); packet = generatePacket(); BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 1, QType::A), 1); BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 1, QType::AAAA), 1); - BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 2, QType::SOA), 1); BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 3, QType::A), 1); toremove = {QType::A, QType::AAAA}; @@ -551,7 +544,6 @@ BOOST_AUTO_TEST_CASE(test_clearDNSPacketRecordTypes) { BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 1, QType::A), 0); BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 1, QType::AAAA), 0); - BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 2, QType::SOA), 1); BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 3, QType::A), 0); } diff --git a/regression-tests.dnsdist/test_Responses.py b/regression-tests.dnsdist/test_Responses.py index cd76fbe815..301048b164 100644 --- a/regression-tests.dnsdist/test_Responses.py +++ b/regression-tests.dnsdist/test_Responses.py @@ -414,7 +414,8 @@ class TestResponseClearRecordsType(DNSDistTest): newServer{address="127.0.0.1:%s"} - addResponseAction("ffi.clear-records-type.responses.tests.powerdns.com.", LuaResponseAction(luafct)) + addResponseAction("ffi.clear-records-type.responses.tests.powerdns.com.", LuaFFIResponseAction(luafct)) + addResponseAction("clear-records-type.responses.tests.powerdns.com.", ClearRecordTypesResponseAction(DNSQType.AAAA)) """ def testClearedFFI(self): @@ -444,3 +445,31 @@ class TestResponseClearRecordsType(DNSDistTest): receivedQuery.id = query.id self.assertEqual(query, receivedQuery) self.assertEqual(expectedResponse, receivedResponse) + + def testCleared(self): + """ + Responses: Removes records of a given type + """ + name = 'clear-records-type.responses.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + expectedResponse = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '192.0.2.1') + response.answer.append(rrset) + expectedResponse.answer.append(rrset) + rrset = dns.rrset.from_text(name, + 3660, + dns.rdataclass.IN, + dns.rdatatype.AAAA, + '2001:DB8::1', '2001:DB8::2') + response.answer.append(rrset) + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (receivedQuery, receivedResponse) = sender(query, response) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(expectedResponse, receivedResponse) -- 2.47.2